diff --git a/scripts/generate_select.py b/scripts/generate_select.py index cbb842b367..66729c69dc 100644 --- a/scripts/generate_select.py +++ b/scripts/generate_select.py @@ -37,7 +37,7 @@ class Arg(BaseModel): else: t_type = f"_T{i}" t_var = f"_TCCA[{t_type}]" - arg = Arg(name=f"__ent{i}", annotation=t_var) + arg = Arg(name=f"ent{i}", annotation=t_var) ret_type = t_type args.append(arg) return_types.append(ret_type) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..c42e17bd62 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -41,7 +41,6 @@ from sqlalchemy.orm import ( Mapped, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -50,6 +49,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid +from sqlalchemy.types import TypeEngine from typing_extensions import deprecated from ._compat import ( # type: ignore[attr-defined] @@ -209,7 +209,7 @@ class FieldInfoMetadata: ondelete: OnDeleteType | UndefinedType = Undefined unique: bool | UndefinedType = Undefined index: bool | UndefinedType = Undefined - sa_type: type[Any] | UndefinedType = Undefined + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined sa_column: Column[Any] | UndefinedType = Undefined sa_column_args: Sequence[Any] | UndefinedType = Undefined sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined @@ -268,7 +268,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -312,7 +312,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -397,7 +397,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column: Column | UndefinedType = Undefined, # type: ignore sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, @@ -566,6 +566,14 @@ def __new__( "__sqlmodel_relationships__": relationships, "__annotations__": pydantic_annotations, } + # Set default __tablename__ before class creation so it's part of the + # class dict rather than assigned through the class object afterwards. + is_table = ( + kwargs.get("table") is True + or (class_dict.get("model_config") or {}).get("table") is True + ) + if is_table and "__tablename__" not in class_dict: + dict_used["__tablename__"] = name.lower() # Duplicate logic from Pydantic to filter config kwargs because if they are # passed directly including the registry Pydantic will pass them over to the # superclass causing an error @@ -865,10 +873,6 @@ def __repr_args__(self) -> Sequence[tuple[str | None, Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod def model_validate( # type: ignore[override] cls: type[_TSQLModel], diff --git a/sqlmodel/sql/_expression_select_gen.py b/sqlmodel/sql/_expression_select_gen.py index 83d934c68b..4227bfdd0a 100644 --- a/sqlmodel/sql/_expression_select_gen.py +++ b/sqlmodel/sql/_expression_select_gen.py @@ -103,11 +103,11 @@ @overload -def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... +def select(ent0: _TCCA[_T0], /) -> SelectOfScalar[_T0]: ... @overload -def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(ent0: _TScalar_0, /) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @@ -116,22 +116,25 @@ def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + /, ) -> Select[tuple[_T0, _T1]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, + /, ) -> Select[tuple[_T0, _TScalar_1]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], + /, ) -> Select[tuple[_TScalar_0, _T1]]: ... @@ -139,54 +142,61 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_T0, _T1, _T2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, + /, ) -> Select[tuple[_T0, _T1, _TScalar_2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_T0, _TScalar_1, _T2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_TScalar_0, _T1, _T2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2]]: ... @@ -194,7 +204,8 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2]]: ... @@ -203,114 +214,127 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _T1, _T2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _T1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _TScalar_1, _T2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _TScalar_1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _T1, _T2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2, _TScalar_3]]: ... @@ -318,8 +342,9 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2, _T3]]: ... @@ -327,8 +352,9 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2, _TScalar_3]]: ... @@ -337,7 +363,8 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2, _T3]]: ... @@ -347,6 +374,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... diff --git a/sqlmodel/sql/_expression_select_gen.py.jinja2 b/sqlmodel/sql/_expression_select_gen.py.jinja2 index 2ce54cb2fa..875b00aa7d 100644 --- a/sqlmodel/sql/_expression_select_gen.py.jinja2 +++ b/sqlmodel/sql/_expression_select_gen.py.jinja2 @@ -48,11 +48,11 @@ _T{{ i }} = TypeVar("_T{{ i }}") # Generated TypeVars end @overload -def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... +def select(ent0: _TCCA[_T0], /) -> SelectOfScalar[_T0]: ... @overload -def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(ent0: _TScalar_0, /) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @@ -62,7 +62,7 @@ def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}/, ) -> Select[tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... {% endfor %} diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index 1bfca79503..47b1719e65 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -1,8 +1,10 @@ +from decimal import Decimal from typing import Annotated import pytest -from sqlalchemy import Column, Integer, String -from sqlmodel import Field, SQLModel +from sqlalchemy import BigInteger, Column, Integer, Numeric, String +from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool def test_sa_column_takes_precedence() -> None: @@ -130,3 +132,49 @@ class Item(SQLModel, table=True): sa_column=Column(Integer, primary_key=True), ondelete="CASCADE", ) + + +def test_sa_type_instance_biginteger() -> None: + """sa_type accepts TypeEngine instances, not just classes.""" + + class Record(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + big_number: int | None = Field(default=None, sa_type=BigInteger()) + + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(Record(big_number=2**40)) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Record)).first() + assert row is not None + assert row.big_number == 2**40 + + +def test_sa_type_instance_numeric() -> None: + """sa_type accepts parameterised TypeEngine instances like Numeric(...).""" + + class Price(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + amount: Decimal | None = Field( + default=None, sa_type=Numeric(precision=10, scale=2) + ) + + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(Price(amount=Decimal("19.99"))) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Price)).first() + assert row is not None + assert row.amount == Decimal("19.99") diff --git a/tests/test_tablename.py b/tests/test_tablename.py new file mode 100644 index 0000000000..da82b9e5ca --- /dev/null +++ b/tests/test_tablename.py @@ -0,0 +1,87 @@ +from sqlalchemy import inspect +from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool + + +def _engine(): + return create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + + +def test_default_tablename() -> None: + """table=True models get __tablename__ = classname.lower() by default.""" + + class Gadget(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + + assert Gadget.__tablename__ == "gadget" + + engine = _engine() + SQLModel.metadata.create_all(engine) + assert inspect(engine).has_table("gadget") + + +def test_explicit_tablename() -> None: + """An explicit __tablename__ overrides the default.""" + + class Widget(SQLModel, table=True): + __tablename__ = "custom_widgets" + id: int | None = Field(default=None, primary_key=True) + name: str + + assert Widget.__tablename__ == "custom_widgets" + + engine = _engine() + SQLModel.metadata.create_all(engine) + assert inspect(engine).has_table("custom_widgets") + assert not inspect(engine).has_table("widget") + + with Session(engine) as session: + session.add(Widget(name="sprocket")) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Widget)).first() + assert row is not None + assert row.name == "sprocket" + + +def test_tablename_inheritance_default() -> None: + """A subclass that is also a table gets its own default __tablename__.""" + + class BaseThing(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + kind: str = "base" + + class SubThing(BaseThing, table=True): + extra: str | None = None + + assert BaseThing.__tablename__ == "basething" + assert SubThing.__tablename__ == "subthing" + + +def test_tablename_inheritance_explicit_child() -> None: + """A subclass can set its own __tablename__, visible on the class.""" + + class Vehicle(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + kind: str = "" + + class Truck(Vehicle, table=True): + __tablename__ = "trucks" + payload: int | None = None + + assert Vehicle.__tablename__ == "vehicle" + assert Truck.__tablename__ == "trucks" + + +def test_tablename_not_set_on_plain_model() -> None: + """Non-table models don't get a __tablename__ injected.""" + + class Schema(SQLModel): + name: str + + assert not hasattr(Schema, "__tablename__") or not isinstance( + Schema.__tablename__, str + )