diff --git a/distconv/distconv.py b/distconv/distconv.py
index 7e9ed6b..8126d47 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -842,7 +842,9 @@ class _ToTensor(Function):
     @staticmethod
     def forward(ctx, dc_tensor: DCTensor):
         ctx.parallel_strategy = dc_tensor._parallel_strategy
-        return dc_tensor._tensor
+        # Need to alias the tensor to prevent the returned inner tensor from
+        # creating circular references when its grad_fn gets modified
+        return torch.Tensor(dc_tensor._tensor)
 
     @staticmethod
     def backward(ctx, grad: torch.Tensor):