Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ repos:
language: unsupported
types: [python]

- id: local-mypy
name: mypy check
entry: uv run mypy sqlmodel tests/test_select_typing.py
- id: local-ty
name: ty check
entry: uv run ty check sqlmodel
require_serial: true
language: unsupported
pass_filenames: false
Expand Down
12 changes: 1 addition & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ tests = [
"fastapi >=0.128.0",
"httpx >=0.28.1",
"jinja2 >=3.1.6",
"mypy >=1.19.1",
"pytest >=7.0.1",
"ruff >=0.15.6",
"ty>=0.0.9",
"typing-extensions >=4.15.0",
]

Expand Down Expand Up @@ -124,16 +124,6 @@ exclude_lines = [
[tool.coverage.html]
show_contexts = true

[tool.mypy]
strict = true
exclude = "sqlmodel.sql._expression_select_gen"

[[tool.mypy.overrides]]
module = "docs_src.*"
disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_untyped_calls = false

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Arg(BaseModel):
else:
t_type = f"_T{i}"
t_var = f"_TCCA[{t_type}]"
arg = Arg(name=f"__ent{i}", annotation=t_var)
arg = Arg(name=f"ent{i}", annotation=t_var)
ret_type = t_type
args.append(arg)
return_types.append(ret_type)
Expand Down
4 changes: 2 additions & 2 deletions scripts/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e
set -x

mypy sqlmodel
mypy tests/test_select_typing.py
ty check sqlmodel
ty check tests/test_select_typing.py
ruff check sqlmodel tests docs_src scripts
ruff format sqlmodel tests docs_src scripts --check
21 changes: 9 additions & 12 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import builtins
import ipaddress
import uuid
import weakref
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -52,7 +51,7 @@
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated

from ._compat import ( # type: ignore[attr-defined]
from ._compat import (
PYDANTIC_MINOR_VERSION,
BaseConfig,
ModelMetaclass,
Expand Down Expand Up @@ -177,7 +176,7 @@ def __init__(
cascade_delete: bool | None = False,
passive_deletes: bool | Literal["all"] | None = False,
link_model: Any | None = None,
sa_relationship: RelationshipProperty | None = None, # type: ignore
sa_relationship: RelationshipProperty | None = None,
sa_relationship_args: Sequence[Any] | None = None,
sa_relationship_kwargs: Mapping[str, Any] | None = None,
) -> None:
Expand Down Expand Up @@ -398,7 +397,7 @@ def Field(
nullable: bool | UndefinedType = Undefined,
index: bool | UndefinedType = Undefined,
sa_type: type[Any] | UndefinedType = Undefined,
sa_column: Column | UndefinedType = Undefined, # type: ignore
sa_column: Column | UndefinedType = Undefined,
sa_column_args: Sequence[Any] | UndefinedType = Undefined,
sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,
schema_extra: dict[str, Any] | None = None,
Expand Down Expand Up @@ -525,13 +524,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
model_fields: ClassVar[dict[str, FieldInfo]]

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
def __setattr__(cls, name: str, value: Any) -> None: # ty: ignore[invalid-method-override]
if is_table_model_class(cls):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)

def __delattr__(cls, name: str) -> None:
def __delattr__(cls, name: str) -> None: # ty: ignore[invalid-method-override]
if is_table_model_class(cls):
DeclarativeMeta.__delattr__(cls, name)
else:
Expand Down Expand Up @@ -649,7 +648,7 @@ def __init__(
# Plain forward references, for models not yet defined, are not
# handled well by SQLAlchemy without Mapped, so, wrap the
# annotations in Mapped here
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
Expand Down Expand Up @@ -738,7 +737,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Column:
field_info = field
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
Expand Down Expand Up @@ -773,7 +772,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
assert isinstance(foreign_key, str)
assert isinstance(ondelete_value, (str, type(None))) # for typing
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
kwargs = {
kwargs: dict[str, Any] = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
Expand All @@ -797,8 +796,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
return Column(sa_type, *args, **kwargs)


class_registry = weakref.WeakValueDictionary() # type: ignore

default_registry = registry()

_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
Expand Down Expand Up @@ -850,7 +847,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call]
if is_table_model_class(self.__class__) and is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
Expand Down
4 changes: 2 additions & 2 deletions sqlmodel/sql/_expression_select_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def where(self, *whereclause: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `WHERE` clause, joined to the existing clause via `AND`, if any.
"""
return super().where(*whereclause) # type: ignore[arg-type]
return super().where(*whereclause)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ty's mistake that it doesn't notice type mismatch here. But I checked - passing bool argument works in runtime


def having(self, *having: _ColumnExpressionArgument[bool] | bool) -> Self:
"""Return a new `Select` construct with the given expression added to
its `HAVING` clause, joined to the existing clause via `AND`, if any.
"""
return super().having(*having) # type: ignore[arg-type]
return super().having(*having)


class Select(SelectBase[_T]):
Expand Down
Loading
Loading