From d5d309008719fa9ec240070a2d18622404ce57bf Mon Sep 17 00:00:00 2001 From: Ewgenij Starostin Date: Wed, 18 Mar 2026 13:03:28 +0100 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=90=9B=20Fix=20`select()`=20overloads?= =?UTF-8?q?=20mixing=20positional-only=20parameter=20styles?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the `__`-prefix convention (PEP 484) with explicit `/` syntax (PEP 570) for positional-only parameters in generated `select()` overloads. Pyright enforces that `__`-prefixed (positional-only) parameters cannot appear before non-prefixed parameters in the same signature, which caused 30 errors in `_expression_select_gen.py`. Mypy does not enforce this, so it was never caught. Both the Jinja2 template and the generator script are updated so that `test_select_gen` continues to pass. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/generate_select.py | 2 +- sqlmodel/sql/_expression_select_gen.py | 128 +++++++++++------- sqlmodel/sql/_expression_select_gen.py.jinja2 | 6 +- 3 files changed, 82 insertions(+), 54 deletions(-) diff --git a/scripts/generate_select.py b/scripts/generate_select.py index cbb842b367..66729c69dc 100644 --- a/scripts/generate_select.py +++ b/scripts/generate_select.py @@ -37,7 +37,7 @@ class Arg(BaseModel): else: t_type = f"_T{i}" t_var = f"_TCCA[{t_type}]" - arg = Arg(name=f"__ent{i}", annotation=t_var) + arg = Arg(name=f"ent{i}", annotation=t_var) ret_type = t_type args.append(arg) return_types.append(ret_type) diff --git a/sqlmodel/sql/_expression_select_gen.py b/sqlmodel/sql/_expression_select_gen.py index 83d934c68b..4227bfdd0a 100644 --- a/sqlmodel/sql/_expression_select_gen.py +++ b/sqlmodel/sql/_expression_select_gen.py @@ -103,11 +103,11 @@ @overload -def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... +def select(ent0: _TCCA[_T0], /) -> SelectOfScalar[_T0]: ... @overload -def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(ent0: _TScalar_0, /) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @@ -116,22 +116,25 @@ def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + /, ) -> Select[tuple[_T0, _T1]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, + /, ) -> Select[tuple[_T0, _TScalar_1]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], + /, ) -> Select[tuple[_TScalar_0, _T1]]: ... @@ -139,54 +142,61 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_T0, _T1, _T2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, + /, ) -> Select[tuple[_T0, _T1, _TScalar_2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_T0, _TScalar_1, _T2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_TScalar_0, _T1, _T2]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2]]: ... @@ -194,7 +204,8 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2]]: ... @@ -203,114 +214,127 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _T1, _T2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], - __ent1: _TCCA[_T1], + ent0: _TCCA[_T0], + ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _T1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _TScalar_1, _T2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _TScalar_1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore - __ent0: _TCCA[_T0], + ent0: _TCCA[_T0], entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_T0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _T1, _T2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], - __ent2: _TCCA[_T2], + ent1: _TCCA[_T1], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _T1, _T2, _TScalar_3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2, _T3]]: ... @overload def select( # type: ignore entity_0: _TScalar_0, - __ent1: _TCCA[_T1], + ent1: _TCCA[_T1], entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _T1, _TScalar_2, _TScalar_3]]: ... @@ -318,8 +342,9 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], - __ent3: _TCCA[_T3], + ent2: _TCCA[_T2], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2, _T3]]: ... @@ -327,8 +352,9 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - __ent2: _TCCA[_T2], + ent2: _TCCA[_T2], entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _T2, _TScalar_3]]: ... @@ -337,7 +363,8 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - __ent3: _TCCA[_T3], + ent3: _TCCA[_T3], + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2, _T3]]: ... @@ -347,6 +374,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + /, ) -> Select[tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... diff --git a/sqlmodel/sql/_expression_select_gen.py.jinja2 b/sqlmodel/sql/_expression_select_gen.py.jinja2 index 2ce54cb2fa..875b00aa7d 100644 --- a/sqlmodel/sql/_expression_select_gen.py.jinja2 +++ b/sqlmodel/sql/_expression_select_gen.py.jinja2 @@ -48,11 +48,11 @@ _T{{ i }} = TypeVar("_T{{ i }}") # Generated TypeVars end @overload -def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]: ... +def select(ent0: _TCCA[_T0], /) -> SelectOfScalar[_T0]: ... @overload -def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(ent0: _TScalar_0, /) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @@ -62,7 +62,7 @@ def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}/, ) -> Select[tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... {% endfor %} From c619044fa3b673e3126983f870370548bc221971 Mon Sep 17 00:00:00 2001 From: Ewgenij Starostin Date: Wed, 18 Mar 2026 13:07:21 +0100 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=90=9B=20Widen=20`sa=5Ftype`=20annota?= =?UTF-8?q?tion=20to=20accept=20`TypeEngine`=20instances?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `sa_type` parameter on `Field()` was typed as `type[Any]` (class only), but the runtime in `get_sqlalchemy_type()` already accepts both classes and instances. Users passing instances like `BigInteger()` or `Numeric(precision=10, scale=2)` got false type errors from pyright. Widen the annotation to `type[Any] | TypeEngine[Any]` in the `FieldInfoMetadata` dataclass, all three `Field()` overloads, and the implementation. Add tests for `BigInteger()` and `Numeric(...)` instances in `test_field_sa_column.py`. Co-Authored-By: Claude Opus 4.6 (1M context) --- sqlmodel/main.py | 9 +++--- tests/test_field_sa_column.py | 52 +++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..7f803e3b65 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -50,6 +50,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid +from sqlalchemy.types import TypeEngine from typing_extensions import deprecated from ._compat import ( # type: ignore[attr-defined] @@ -209,7 +210,7 @@ class FieldInfoMetadata: ondelete: OnDeleteType | UndefinedType = Undefined unique: bool | UndefinedType = Undefined index: bool | UndefinedType = Undefined - sa_type: type[Any] | UndefinedType = Undefined + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined sa_column: Column[Any] | UndefinedType = Undefined sa_column_args: Sequence[Any] | UndefinedType = Undefined sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined @@ -268,7 +269,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -312,7 +313,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -397,7 +398,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: type[Any] | TypeEngine[Any] | UndefinedType = Undefined, sa_column: Column | UndefinedType = Undefined, # type: ignore sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index 1bfca79503..47b1719e65 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -1,8 +1,10 @@ +from decimal import Decimal from typing import Annotated import pytest -from sqlalchemy import Column, Integer, String -from sqlmodel import Field, SQLModel +from sqlalchemy import BigInteger, Column, Integer, Numeric, String +from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool def test_sa_column_takes_precedence() -> None: @@ -130,3 +132,49 @@ class Item(SQLModel, table=True): sa_column=Column(Integer, primary_key=True), ondelete="CASCADE", ) + + +def test_sa_type_instance_biginteger() -> None: + """sa_type accepts TypeEngine instances, not just classes.""" + + class Record(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + big_number: int | None = Field(default=None, sa_type=BigInteger()) + + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(Record(big_number=2**40)) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Record)).first() + assert row is not None + assert row.big_number == 2**40 + + +def test_sa_type_instance_numeric() -> None: + """sa_type accepts parameterised TypeEngine instances like Numeric(...).""" + + class Price(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + amount: Decimal | None = Field( + default=None, sa_type=Numeric(precision=10, scale=2) + ) + + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(Price(amount=Decimal("19.99"))) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Price)).first() + assert row is not None + assert row.amount == Decimal("19.99") From f738f673f0389b5f9d2d8c05d3e03c8eb9f594c8 Mon Sep 17 00:00:00 2001 From: Ewgenij Starostin Date: Wed, 18 Mar 2026 13:10:33 +0100 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=90=9B=20Move=20`=5F=5Ftablename=5F?= =?UTF-8?q?=5F`=20default=20from=20`@declared=5Fattr`=20to=20metaclass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `SQLModel` base class declared `__tablename__` both as a `ClassVar` and as a `@declared_attr` method. Pyright saw the descriptor type from `@declared_attr`, causing `reportAssignmentType` errors when subclasses set `__tablename__ = "my_table"` — a pattern that works at runtime. Replace the `@declared_attr` method with a default set in `SQLModelMetaclass.__new__` via `dict_used`, before class creation. This preserves the `ClassVar[str | Callable]` declaration and makes plain string assignments work without type errors. Fixes #98. Co-Authored-By: Claude Opus 4.6 (1M context) --- sqlmodel/main.py | 13 +++--- tests/test_tablename.py | 87 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 tests/test_tablename.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7f803e3b65..c42e17bd62 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -41,7 +41,6 @@ from sqlalchemy.orm import ( Mapped, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -567,6 +566,14 @@ def __new__( "__sqlmodel_relationships__": relationships, "__annotations__": pydantic_annotations, } + # Set default __tablename__ before class creation so it's part of the + # class dict rather than assigned through the class object afterwards. + is_table = ( + kwargs.get("table") is True + or (class_dict.get("model_config") or {}).get("table") is True + ) + if is_table and "__tablename__" not in class_dict: + dict_used["__tablename__"] = name.lower() # Duplicate logic from Pydantic to filter config kwargs because if they are # passed directly including the registry Pydantic will pass them over to the # superclass causing an error @@ -866,10 +873,6 @@ def __repr_args__(self) -> Sequence[tuple[str | None, Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod def model_validate( # type: ignore[override] cls: type[_TSQLModel], diff --git a/tests/test_tablename.py b/tests/test_tablename.py new file mode 100644 index 0000000000..da82b9e5ca --- /dev/null +++ b/tests/test_tablename.py @@ -0,0 +1,87 @@ +from sqlalchemy import inspect +from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool + + +def _engine(): + return create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + + +def test_default_tablename() -> None: + """table=True models get __tablename__ = classname.lower() by default.""" + + class Gadget(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + + assert Gadget.__tablename__ == "gadget" + + engine = _engine() + SQLModel.metadata.create_all(engine) + assert inspect(engine).has_table("gadget") + + +def test_explicit_tablename() -> None: + """An explicit __tablename__ overrides the default.""" + + class Widget(SQLModel, table=True): + __tablename__ = "custom_widgets" + id: int | None = Field(default=None, primary_key=True) + name: str + + assert Widget.__tablename__ == "custom_widgets" + + engine = _engine() + SQLModel.metadata.create_all(engine) + assert inspect(engine).has_table("custom_widgets") + assert not inspect(engine).has_table("widget") + + with Session(engine) as session: + session.add(Widget(name="sprocket")) + session.commit() + + with Session(engine) as session: + row = session.exec(select(Widget)).first() + assert row is not None + assert row.name == "sprocket" + + +def test_tablename_inheritance_default() -> None: + """A subclass that is also a table gets its own default __tablename__.""" + + class BaseThing(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + kind: str = "base" + + class SubThing(BaseThing, table=True): + extra: str | None = None + + assert BaseThing.__tablename__ == "basething" + assert SubThing.__tablename__ == "subthing" + + +def test_tablename_inheritance_explicit_child() -> None: + """A subclass can set its own __tablename__, visible on the class.""" + + class Vehicle(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + kind: str = "" + + class Truck(Vehicle, table=True): + __tablename__ = "trucks" + payload: int | None = None + + assert Vehicle.__tablename__ == "vehicle" + assert Truck.__tablename__ == "trucks" + + +def test_tablename_not_set_on_plain_model() -> None: + """Non-table models don't get a __tablename__ injected.""" + + class Schema(SQLModel): + name: str + + assert not hasattr(Schema, "__tablename__") or not isinstance( + Schema.__tablename__, str + )