From 2b3dcf6a68d6d3707a3ebea86a5dda0edcac1bad Mon Sep 17 00:00:00 2001 From: Thomas Neidhart Date: Fri, 16 Feb 2024 11:01:15 +0100 Subject: [PATCH] Support multiple optional fields in a model by removing a hack that resulted in all optional fields being assigned the same type --- odmantic/model.py | 8 +++----- tests/unit/test_field.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/odmantic/model.py b/odmantic/model.py index 74e42272..2e53ad31 100644 --- a/odmantic/model.py +++ b/odmantic/model.py @@ -194,11 +194,9 @@ def validate_type(type_: Type) -> Type: # generics is found # https://github.com/pydantic/pydantic/issues/8354 if type_origin is Union: - new_root = Union[ - int, str - ] # We don't care about int,str since they will be replaced - setattr(new_root, "__args__", new_arg_types) - type_ = new_root # type: ignore + # as new_arg_types is a tuple, we can directly create a matching Union instance, + # instead of hacking our way around it: https://stackoverflow.com/a/72884529/3784643 + type_ = Union[new_arg_types] # type: ignore else: type_ = GenericAlias(type_origin, new_arg_types) # type: ignore return type_ diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index 6f140f74..1e044a9c 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -1,3 +1,7 @@ +from datetime import datetime +from typing import Optional + +import odmantic import pytest from odmantic.field import Field @@ -89,3 +93,19 @@ class M(Model): } assert not M.__odm_fields__["field"].is_required_in_doc() + + +def test_multiple_optional_fields(): + class M(Model): + field: str = Field(default_factory=lambda: "hi") # pragma: no cover + optionalBoolField: Optional[bool] = None + optionalDatetimeField: Optional[datetime] = None + + assert M.__odm_fields__["optionalBoolField"].pydantic_field.annotation == Optional[bool] + assert M.__odm_fields__["optionalDatetimeField"].pydantic_field.annotation == Optional[odmantic.bson._datetime] + + try: + instance = M(field="Hi") + instance.optionalBoolField = True + except: + pytest.fail("a boolean value can not be assigned to a boolean field")