diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 056c4765..2f415d2e 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -1197,5 +1197,112 @@ def forward(self, a, *args, **kwargs): torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds) + def test_remove_inputs_kwargs(self): + """Test that remove_inputs removes a kwarg from the observer info.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None): + r = x + y + if z is not None: + r += z + return r + + inputs = [ + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((5, 6))), + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7)), z=torch.randn((7, 7))), + dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)), z=torch.randn((7, 8))), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("x", ds) + self.assertIn("y", ds) + + # Remove z input + observer.remove_inputs(["z"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + + def test_remove_inputs_multiple_kwargs(self): + """Test that remove_inputs removes multiple kwargs at once.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None, w=None): + r = x + y + if z is not None: + r += z + if w is not None: + r += w + return r + + inputs = [ + dict( + x=torch.randn((5, 6)), + y=torch.randn((1, 6)), + z=torch.randn((5, 6)), + w=torch.randn((1, 6)), + ), + dict( + x=torch.randn((6, 7)), + y=torch.randn((1, 7)), + z=torch.randn((6, 7)), + w=torch.randn((1, 7)), + ), + dict( + x=torch.randn((7, 8)), + y=torch.randn((1, 8)), + z=torch.randn((7, 8)), + w=torch.randn((1, 8)), + ), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("w", ds) + + # Remove z and w inputs + observer.remove_inputs(["z", "w"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertNotIn("w", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertNotIn("w", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + + if __name__ == "__main__": unittest.main(verbosity=2)