diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a220b193f1..11457bd911 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -205,74 +205,6 @@ def get_field_metadata(field: Any) -> Any: return FakeMetadata() -def sqlmodel_table_construct( - *, - self_instance: _TSQLModel, - values: dict[str, Any], - _fields_set: set[str] | None = None, -) -> _TSQLModel: - # Copy from Pydantic's BaseModel.construct() - # Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198 - # Modified to not include everything, only the model fields, and to - # set relationships - # SQLModel override to get class SQLAlchemy __dict__ attributes and - # set them back in after creating the object - # new_obj = cls.__new__(cls) - cls = type(self_instance) - old_dict = self_instance.__dict__.copy() - # End SQLModel override - - fields_values: dict[str, Any] = {} - defaults: dict[ - str, Any - ] = {} # keeping this separate from `fields_values` helps us compute `_fields_set` - for name, field in cls.model_fields.items(): - if field.alias and field.alias in values: - fields_values[name] = values.pop(field.alias) - elif name in values: - fields_values[name] = values.pop(name) - elif not field.is_required(): - defaults[name] = field.get_default(call_default_factory=True) - if _fields_set is None: - _fields_set = set(fields_values.keys()) - fields_values.update(defaults) - - _extra: dict[str, Any] | None = None - if cls.model_config.get("extra") == "allow": - _extra = {} - for k, v in values.items(): - _extra[k] = v - # SQLModel override, do not include everything, only the model fields - # else: - # fields_values.update(values) - # End SQLModel override - # SQLModel override - # Do not set __dict__, instead use setattr to trigger SQLAlchemy - # object.__setattr__(new_obj, "__dict__", fields_values) - # instrumentation - for key, value in {**old_dict, **fields_values}.items(): - setattr(self_instance, key, value) - # End SQLModel override - object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set) - if not cls.__pydantic_root_model__: - object.__setattr__(self_instance, "__pydantic_extra__", _extra) - - if cls.__pydantic_post_init__: - self_instance.model_post_init(None) - elif not cls.__pydantic_root_model__: - # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist - # Since it doesn't, that means that `__pydantic_private__` should be set to None - object.__setattr__(self_instance, "__pydantic_private__", None) - # SQLModel override, set relationships - # Get and set any relationship objects - for key in self_instance.__sqlmodel_relationships__: - value = values.get(key, Undefined) - if value is not Undefined: - setattr(self_instance, key, value) - # End SQLModel override - return self_instance - - def sqlmodel_validate( cls: type[_TSQLModel], obj: Any, @@ -328,18 +260,22 @@ def sqlmodel_validate( def sqlmodel_init(*, self: "SQLModel", data: dict[str, Any]) -> None: old_dict = self.__dict__.copy() + self.__pydantic_validator__.validate_python( + data, + self_instance=self, + ) if not is_table_model_class(self.__class__): - self.__pydantic_validator__.validate_python( - data, - self_instance=self, + object.__setattr__( + self, + "__dict__", + {**old_dict, **self.__dict__}, ) else: - sqlmodel_table_construct( - self_instance=self, - values=data, - ) - object.__setattr__( - self, - "__dict__", - {**old_dict, **self.__dict__}, - ) + fields_set = self.__pydantic_fields_set__.copy() + for key, value in {**old_dict, **self.__dict__}.items(): + setattr(self, key, value) + object.__setattr__(self, "__pydantic_fields_set__", fields_set) + for key in self.__sqlmodel_relationships__: + value = data.get(key, Undefined) + if value is not Undefined: + setattr(self, key, value) diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 72680dfff9..c15d5a8bb3 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -3,7 +3,17 @@ from sqlmodel import Field, Session, SQLModel, create_engine, select -def test_allow_instantiation_without_arguments(clear_sqlmodel): +def test_not_allow_instantiation_without_arguments(clear_sqlmodel): + class Item(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + description: str | None = None + + with pytest.raises(ValidationError): + Item() + + +def test_allow_instantiation_with_required_arguments(clear_sqlmodel): class Item(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) name: str @@ -12,8 +22,7 @@ class Item(SQLModel, table=True): engine = create_engine("sqlite:///:memory:") SQLModel.metadata.create_all(engine) with Session(engine) as db: - item = Item() - item.name = "Rick" + item = Item(name="Rick") db.add(item) db.commit() statement = select(Item) @@ -21,13 +30,3 @@ class Item(SQLModel, table=True): assert len(result) == 1 assert isinstance(item.id, int) SQLModel.metadata.clear() - - -def test_not_allow_instantiation_without_arguments_if_not_table(): - class Item(SQLModel): - id: int | None = Field(default=None, primary_key=True) - name: str - description: str | None = None - - with pytest.raises(ValidationError): - Item() diff --git a/tests/test_validation.py b/tests/test_validation.py index 47fbca87c2..67d636f6b1 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,6 +1,6 @@ import pytest from pydantic.error_wrappers import ValidationError -from sqlmodel import SQLModel +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select def test_validation_pydantic_v2(clear_sqlmodel): @@ -29,3 +29,242 @@ def reject_none(cls, v): with pytest.raises(ValidationError): Hero.model_validate({"name": None, "age": 25}) + + +def test_table_model_field_validator(clear_sqlmodel): + from pydantic import field_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + @field_validator("name") + @classmethod + def name_must_not_be_empty(cls, v: str) -> str: + if not v.strip(): + raise ValueError("name must not be empty") + return v + + Hero(name="Deadpond", age=25) + + with pytest.raises(ValidationError): + Hero(name="", age=25) + + +def test_table_model_field_validator_before_mode(clear_sqlmodel): + from pydantic import field_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + + @field_validator("name", mode="before") + @classmethod + def coerce_name(cls, v: object) -> str: + if isinstance(v, int): + return f"Hero-{v}" + return v + + hero = Hero(name=42) + assert hero.name == "Hero-42" + + +def test_table_model_model_validator_after(clear_sqlmodel): + from pydantic import model_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_name: str + + @model_validator(mode="after") + def names_must_differ(self) -> "Hero": + if self.name == self.secret_name: + raise ValueError("name and secret_name must differ") + return self + + Hero(name="Deadpond", secret_name="Dive Wilson") + + with pytest.raises(ValidationError): + Hero(name="Same", secret_name="Same") + + +def test_table_model_model_validator_before(clear_sqlmodel): + from pydantic import model_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + + @model_validator(mode="before") + @classmethod + def uppercase_name(cls, data: dict) -> dict: + if "name" in data: + data["name"] = data["name"].upper() + return data + + hero = Hero(name="deadpond") + assert hero.name == "DEADPOND" + + +def test_table_model_before_validator_annotated(clear_sqlmodel): + from typing import Annotated + + from pydantic import BeforeValidator + + def parse_int(v: object) -> object: + if isinstance(v, str) and v.isdigit(): + return int(v) + return v + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + age: Annotated[int | None, BeforeValidator(parse_int)] = None + + hero = Hero(name="Deadpond", age="25") + assert hero.age == 25 + + +def test_table_model_orm_round_trip_with_validator(clear_sqlmodel): + from pydantic import field_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + @field_validator("age") + @classmethod + def double_age(cls, v: int | None) -> int | None: + if v is not None: + return v * 2 + return v + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + hero = Hero(name="Deadpond", age=25) + assert hero.age == 50 + + with Session(engine) as session: + session.add(hero) + session.commit() + session.refresh(hero) + + with Session(engine) as session: + loaded = session.exec(select(Hero)).first() + assert loaded is not None + assert loaded.name == "Deadpond" + assert loaded.age == 50 + + SQLModel.metadata.clear() + + +def test_validation_does_not_run_on_orm_load(clear_sqlmodel): + from pydantic import field_validator + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + + @field_validator("name") + @classmethod + def name_must_be_short(cls, v: str) -> str: + if len(v) > 5: + raise ValueError("too long") + return v + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(Hero(name="short")) + session.commit() + + with engine.connect() as conn: + conn.execute( + Hero.__table__.update() + .where(Hero.__table__.c.id == 1) + .values(name="this is way too long") + ) + conn.commit() + + with Session(engine) as session: + loaded = session.exec(select(Hero)).first() + assert loaded is not None + assert loaded.name == "this is way too long" + + SQLModel.metadata.clear() + + +def test_table_model_relationship_without_related_object(clear_sqlmodel): + class Team(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + heroes: list["Hero"] = Relationship(back_populates="team") + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + team_id: int | None = Field(default=None, foreign_key="team.id") + team: Team | None = Relationship(back_populates="heroes") + + team = Team(name="Preventers") + hero = Hero(name="Deadpond") + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(team) + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.team is None + + SQLModel.metadata.clear() + + +def test_table_model_relationship_assigned_after_construction(clear_sqlmodel): + class Team(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + heroes: list["Hero"] = Relationship(back_populates="team") + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + team_id: int | None = Field(default=None, foreign_key="team.id") + team: Team | None = Relationship(back_populates="heroes") + + team = Team(name="Preventers") + hero = Hero(name="Deadpond") + hero.team = team + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.team is not None + assert hero.team.name == "Preventers" + + SQLModel.metadata.clear() + + +def test_table_model_model_validate_still_works(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + hero = Hero.model_validate({"name": "Deadpond", "age": 25}) + assert hero.name == "Deadpond" + assert hero.age == 25 + + with pytest.raises(ValidationError): + Hero.model_validate({"name": "Deadpond", "age": "not a number"})