Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 16 additions & 80 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,74 +205,6 @@ def get_field_metadata(field: Any) -> Any:
return FakeMetadata()


def sqlmodel_table_construct(
*,
self_instance: _TSQLModel,
values: dict[str, Any],
_fields_set: set[str] | None = None,
) -> _TSQLModel:
# Copy from Pydantic's BaseModel.construct()
# Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198
# Modified to not include everything, only the model fields, and to
# set relationships
# SQLModel override to get class SQLAlchemy __dict__ attributes and
# set them back in after creating the object
# new_obj = cls.__new__(cls)
cls = type(self_instance)
old_dict = self_instance.__dict__.copy()
# End SQLModel override

fields_values: dict[str, Any] = {}
defaults: dict[
str, Any
] = {} # keeping this separate from `fields_values` helps us compute `_fields_set`
for name, field in cls.model_fields.items():
if field.alias and field.alias in values:
fields_values[name] = values.pop(field.alias)
elif name in values:
fields_values[name] = values.pop(name)
elif not field.is_required():
defaults[name] = field.get_default(call_default_factory=True)
if _fields_set is None:
_fields_set = set(fields_values.keys())
fields_values.update(defaults)

_extra: dict[str, Any] | None = None
if cls.model_config.get("extra") == "allow":
_extra = {}
for k, v in values.items():
_extra[k] = v
# SQLModel override, do not include everything, only the model fields
# else:
# fields_values.update(values)
# End SQLModel override
# SQLModel override
# Do not set __dict__, instead use setattr to trigger SQLAlchemy
# object.__setattr__(new_obj, "__dict__", fields_values)
# instrumentation
for key, value in {**old_dict, **fields_values}.items():
setattr(self_instance, key, value)
# End SQLModel override
object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set)
if not cls.__pydantic_root_model__:
object.__setattr__(self_instance, "__pydantic_extra__", _extra)

if cls.__pydantic_post_init__:
self_instance.model_post_init(None)
elif not cls.__pydantic_root_model__:
# Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
# Since it doesn't, that means that `__pydantic_private__` should be set to None
object.__setattr__(self_instance, "__pydantic_private__", None)
# SQLModel override, set relationships
# Get and set any relationship objects
for key in self_instance.__sqlmodel_relationships__:
value = values.get(key, Undefined)
if value is not Undefined:
setattr(self_instance, key, value)
# End SQLModel override
return self_instance


def sqlmodel_validate(
cls: type[_TSQLModel],
obj: Any,
Expand Down Expand Up @@ -328,18 +260,22 @@ def sqlmodel_validate(

def sqlmodel_init(*, self: "SQLModel", data: dict[str, Any]) -> None:
old_dict = self.__dict__.copy()
self.__pydantic_validator__.validate_python(
data,
self_instance=self,
)
if not is_table_model_class(self.__class__):
self.__pydantic_validator__.validate_python(
data,
self_instance=self,
object.__setattr__(
self,
"__dict__",
{**old_dict, **self.__dict__},
)
else:
sqlmodel_table_construct(
self_instance=self,
values=data,
)
object.__setattr__(
self,
"__dict__",
{**old_dict, **self.__dict__},
)
fields_set = self.__pydantic_fields_set__.copy()
for key, value in {**old_dict, **self.__dict__}.items():
setattr(self, key, value)
object.__setattr__(self, "__pydantic_fields_set__", fields_set)
for key in self.__sqlmodel_relationships__:
value = data.get(key, Undefined)
if value is not Undefined:
setattr(self, key, value)
25 changes: 12 additions & 13 deletions tests/test_instance_no_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
from sqlmodel import Field, Session, SQLModel, create_engine, select


def test_allow_instantiation_without_arguments(clear_sqlmodel):
def test_not_allow_instantiation_without_arguments(clear_sqlmodel):
class Item(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str
description: str | None = None

with pytest.raises(ValidationError):
Item()


def test_allow_instantiation_with_required_arguments(clear_sqlmodel):
class Item(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str
Expand All @@ -12,22 +22,11 @@ class Item(SQLModel, table=True):
engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
item = Item()
item.name = "Rick"
item = Item(name="Rick")
db.add(item)
db.commit()
statement = select(Item)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(item.id, int)
SQLModel.metadata.clear()


def test_not_allow_instantiation_without_arguments_if_not_table():
class Item(SQLModel):
id: int | None = Field(default=None, primary_key=True)
name: str
description: str | None = None

with pytest.raises(ValidationError):
Item()
Loading
Loading