From a73969fc3b8c5da6acf92bd8b2f9262f15ad8a2e Mon Sep 17 00:00:00 2001
From: Leo Germond
Date: Tue, 14 Oct 2025 09:16:47 +0200
Subject: [PATCH] Support dict[K, V] annotations in type checking

Add a REPEAT_KV child mode so that dict[K, V] annotations check every
key against K and every value against V, both structurally and for
Annotated constraints. A value that is not a dict is reported as a
structure error before its items are iterated.
---
 src/contractme/typecheck.py | 26 ++++++++++++++++++++++++++
 tests/test_annotations.py   | 31 +++++++++++++++++++++++++++++--
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/src/contractme/typecheck.py b/src/contractme/typecheck.py
index a2390af..d9ed790 100644
--- a/src/contractme/typecheck.py
+++ b/src/contractme/typecheck.py
@@ -145,6 +145,7 @@ def resolve_type(t: AnnotationForm) -> AnnotationForm:
 class ChildMod(Enum):
     NONE = auto()
     REPEAT = auto()  # 0..n
+    REPEAT_KV = auto()  # {k0: v0, k1: v1, ...}
     OPTION = auto()  # A | B | C
 
 
@@ -176,6 +177,9 @@ class CompositeTypeInfo:
             ), "Cannot have list[T1, T2...] must either be list[T1] or list[tuple[T1, T2]]"
             self.child_mod = ChildMod.REPEAT
             self.children = args
+        elif self.origin in (dict,):
+            self.child_mod = ChildMod.REPEAT_KV
+            self.children = args
         elif self.origin in (types.UnionType, typing.Union):
             self.child_mod = ChildMod.OPTION
             self.children = args
@@ -266,6 +270,20 @@ def get_first_structure_error(value, an, recursive_type_info=None) -> str | None
                 struct.check_cond(not err, str(err))
                 if not struct.is_ok():
                     break
+        elif type_info.child_mod == ChildMod.REPEAT_KV:
+            assert len(type_info.children) == 2
+            key_an, value_an = type_info.children
+            if not isinstance(val, dict):
+                # report a structure error rather than crash on .items()
+                struct.check_cond(False, f"expected dict, got {type(val).__name__}")
+            else:
+                for key, item in val.items():
+                    err = get_first_structure_error(key, key_an)
+                    struct.check_cond(not err, str(err))
+                    err = get_first_structure_error(item, value_an)
+                    struct.check_cond(not err, str(err))
+                    if not struct.is_ok():
+                        break
         elif is_sequence:
             # depth-first
             struct.check_cond(
@@ -324,6 +342,14 @@ def get_constraints_errors(
             for v in val:
                 err = get_constraints_errors(v, type_info.children[0])
                 constraints.check_cond(not err, str(err))
+        elif type_info.child_mod == ChildMod.REPEAT_KV:
+            assert len(type_info.children) == 2
+            key_an, value_an = type_info.children
+            for key, item in val.items():
+                err = get_constraints_errors(key, key_an)
+                constraints.check_cond(not err, str(err))
+                err = get_constraints_errors(item, value_an)
+                constraints.check_cond(not err, str(err))
         elif is_sequence:
             # depth-first
             check_values = list(val) + check_values
diff --git a/tests/test_annotations.py b/tests/test_annotations.py
index 322abb6..3505c4f 100644
--- a/tests/test_annotations.py
+++ b/tests/test_annotations.py
@@ -631,9 +631,9 @@ def test_recursive_type():
     assert f([[[e]]]) is e
     assert f([[[e], e], e]) is e
     with pytest.raises(AssertionError):
-        assert f([[[e]], 1]) is e  # type: ignore
+        _ = f([[[e]], 1])  # type: ignore
     with pytest.raises(AssertionError):
-        assert f([[[e, None]]]) is e  # type: ignore
+        _ = f([[[e, None]]])  # type: ignore
 
 
 type MutRec1 = list[MutRec2]
@@ -651,6 +651,33 @@ def test_mutually_recursive_types():
     assert f([[[0]]]) == 0
 
 
+type PropsMap = dict[str, Annotated[float, Gt(0.0)]]
+
+
+def test_dict():
+    @annotated
+    def get_val(kv: PropsMap, key: str, default: float = 0.0) -> float:
+        if key in kv:
+            return kv[key]
+        else:
+            return default
+
+    assert get_val({"toto": 10.0}, "toto") == 10.0
+    assert get_val({"toto": 10.0}, "tata") == 0.0
+
+    with pytest.raises(AssertionError):
+        _ = get_val({"toto": 10}, "toto")
+
+    with pytest.raises(AssertionError):
+        _ = get_val({"toto": 0.0}, "toto")
+
+    with pytest.raises(AssertionError):
+        _ = get_val({1: 10.0}, "toto")  # type: ignore
+
+    with pytest.raises(AssertionError):
+        _ = get_val([("toto", 10.0)], "toto")  # type: ignore
+
+
 type Height = Annotated[float, Ge(0.0)]
 type InternalPressureFlyingAtm = Annotated[float, Interval(ge=0.0, le=1.0)]
 type Depth = Annotated[float, Le(0.0)]
-- 
GitLab