diff --git a/vendor_test/uses_torch.py b/vendor_test/uses_torch.py index 5804aaff..747ecd51 100644 --- a/vendor_test/uses_torch.py +++ b/vendor_test/uses_torch.py @@ -23,7 +23,7 @@ def _test_torch(): assert isinstance(b, torch.Tensor) assert isinstance(res, torch.Tensor) - torch.testing.assert_allclose(res, [[1., 2., 3.]]) + torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]])) assert is_torch_array(res) assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)