diff --git a/tests/test_torch.py b/tests/test_torch.py index 35ef5dda..b064a46d 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -161,3 +161,18 @@ def test_round(): r = xp.round(x, decimals=1, out=o) assert xp.all(r == o) assert r is o + + +def test_dynamo_array_namespace(): + """Check that torch.compiling array_namespace does not incur graph breaks.""" + from array_api_compat import array_namespace + + def foo(x): + xp = array_namespace(x) + return xp.multiply(x, x) + + bar = torch.compile(fullgraph=True)(foo) + + x = torch.arange(3) + y = bar(x) + assert xp.all(y == x**2)