diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index c02f18791..c1a6ce510 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -204,7 +204,7 @@ def _rename_attributes(table, props): self._key_source *= _rename_attributes(*q) return self._key_source - def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]: + def make(self, key: dict[str, Any], **kwargs) -> None | Generator[Any, Any, None]: """ Compute and insert data for one key. @@ -219,6 +219,9 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]: ---------- key : dict Primary key value identifying the entity to compute. + **kwargs + Keyword arguments passed from ``populate(make_kwargs=...)``. + These are forwarded to ``make_fetch`` for the tripartite pattern. Raises ------ @@ -232,7 +235,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]: **Tripartite make**: For long-running computations, implement: - - ``make_fetch(key)``: Fetch data from parent tables + - ``make_fetch(key, **kwargs)``: Fetch data from parent tables - ``make_compute(key, *fetched_data)``: Compute results - ``make_insert(key, *computed_result)``: Insert results @@ -250,7 +253,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]: # User has implemented `_fetch`, `_compute`, and `_insert` methods instead # Step 1: Fetch data from parent tables - fetched_data = self.make_fetch(key) # fetched_data is a tuple + fetched_data = self.make_fetch(key, **kwargs) # fetched_data is a tuple computed_result = yield fetched_data # passed as input into make_compute # Step 2: If computed result is not passed in, compute the result diff --git a/tests/integration/test_autopopulate.py b/tests/integration/test_autopopulate.py index c9df5f78f..4e6290b99 100644 --- a/tests/integration/test_autopopulate.py +++ b/tests/integration/test_autopopulate.py @@ -147,3 +147,84 @@ def make(self, key): self.insert1(dict(key, crop_image=dict())) Crop.populate() + + +def test_make_kwargs_regular(prefix, connection_test): + """Test that make_kwargs are passed to regular make method.""" + schema = dj.Schema(f"{prefix}_make_kwargs_regular", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id: int + """ + contents = [(1,), (2,)] + + @schema + class Computed(dj.Computed): + definition = """ + -> Source + --- + multiplier: int + result: int + """ + + def make(self, key, multiplier=1): + self.insert1(dict(key, multiplier=multiplier, result=key["source_id"] * multiplier)) + + # Test without make_kwargs + Computed.populate(Source & "source_id = 1") + assert (Computed & "source_id = 1").fetch1("result") == 1 + + # Test with make_kwargs + Computed.populate(Source & "source_id = 2", make_kwargs={"multiplier": 10}) + assert (Computed & "source_id = 2").fetch1("multiplier") == 10 + assert (Computed & "source_id = 2").fetch1("result") == 20 + + +def test_make_kwargs_tripartite(prefix, connection_test): + """Test that make_kwargs are passed to make_fetch in tripartite pattern (issue #1350).""" + schema = dj.Schema(f"{prefix}_make_kwargs_tripartite", connection=connection_test) + + @schema + class Source(dj.Lookup): + definition = """ + source_id: int + --- + value: int + """ + contents = [(1, 100), (2, 200)] + + @schema + class TripartiteComputed(dj.Computed): + definition = """ + -> Source + --- + scale: int + result: int + """ + + def make_fetch(self, key, scale=1): + """Fetch data with optional scale parameter.""" + value = (Source & key).fetch1("value") + return (value, scale) + + def make_compute(self, key, value, scale): + """Compute result using fetched value and scale.""" + return (value * scale, scale) + + def make_insert(self, key, result, scale): + """Insert computed result.""" + self.insert1(dict(key, scale=scale, result=result)) + + # Test without make_kwargs (scale defaults to 1) + TripartiteComputed.populate(Source & "source_id = 1") + row = (TripartiteComputed & "source_id = 1").fetch1() + assert row["scale"] == 1 + assert row["result"] == 100 # 100 * 1 + + # Test with make_kwargs (scale = 5) + TripartiteComputed.populate(Source & "source_id = 2", make_kwargs={"scale": 5}) + row = (TripartiteComputed & "source_id = 2").fetch1() + assert row["scale"] == 5 + assert row["result"] == 1000 # 200 * 5