From 402949db5f17e474d2faee6252b384702a95ee04 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 3 Apr 2026 19:54:56 +0200 Subject: [PATCH] TST: add a test of torch.compile-ing array_namespace --- tests/test_torch.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_torch.py b/tests/test_torch.py index 35ef5dda..b064a46d 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -161,3 +161,18 @@ def test_round(): r = xp.round(x, decimals=1, out=o) assert xp.all(r == o) assert r is o + + +def test_dynamo_array_namespace(): + """Check that torch.compiling array_namespace does not incur graph breaks.""" + from array_api_compat import array_namespace + + def foo(x): + xp = array_namespace(x) + return xp.multiply(x, x) + + bar = torch.compile(fullgraph=True)(foo) + + x = torch.arange(3) + y = bar(x) + assert xp.all(y == x**2)