diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..a91c52294a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -589,12 +589,12 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = new_cls.model_config.get(name, Undefined) - if config_class_value is not Undefined: - return config_class_value kwarg_value = kwargs.get(name, Undefined) if kwarg_value is not Undefined: return kwarg_value + config_class_value = new_cls.model_config.get(name, Undefined) + if config_class_value is not Undefined: + return config_class_value return Undefined config_table = get_config("table") @@ -618,10 +618,15 @@ def get_config(name: str) -> Any: if config_registry is not Undefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config["registry"] = config_table - setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 - setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 - setattr(new_cls, "__abstract__", True) # noqa: B010 + new_cls.model_config["registry"] = config_registry + # Only set up the registry attributes when explicitly passed + # as a kwarg on this class, not when inherited from a parent. + # Setting __abstract__ on subclasses that merely inherit the + # registry would prevent SQLAlchemy from instrumenting them. + if "registry" in kwargs: + setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 + setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 + setattr(new_cls, "__abstract__", True) # noqa: B010 return new_cls # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000000..499eaca8fb --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,43 @@ +from sqlalchemy.orm import registry +from sqlmodel import Field, Session, SQLModel, create_engine, select + + +def test_custom_registry_stored_in_model_config(clear_sqlmodel): + """Test that passing a custom registry via kwargs stores the registry + (not the table config value) in model_config['registry']. + + This is a regression test for a copy-paste bug where model_config['registry'] + was incorrectly set to the value of config_table instead of config_registry. + """ + custom_registry = registry() + + class Base(SQLModel, registry=custom_registry): + pass + + class Hero(Base, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + + # The registry stored in model_config should be the custom registry, + # not a bool (which config_table would be) + assert Base.model_config.get("registry") is custom_registry + assert isinstance(Base.model_config.get("registry"), registry) + + # Verify the custom registry is actually functional + engine = create_engine("sqlite://") + custom_registry.metadata.create_all(engine) + + with Session(engine) as session: + hero = Hero(name="Spider-Boy") + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.id is not None + assert hero.name == "Spider-Boy" + + with Session(engine) as session: + heroes = session.exec(select(Hero)).all() + assert len(heroes) == 1 + assert heroes[0].name == "Spider-Boy" + + custom_registry.dispose()