Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
commit 6ee67639e2
parent 5da5967d08
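For orientation: the tests being moved exercise zero-copy array exchange between JAX and PyTorch via the DLPack protocol. A minimal sketch of the round trip they cover, illustrative only, using the same jax.dlpack and torch.utils.dlpack APIs that appear in the diff below (note that a DLPack capsule is single-use and is consumed by the importing side):

    import jax.dlpack
    import jax.numpy as jnp
    import torch
    import torch.utils.dlpack

    # Torch -> JAX: export the tensor as a DLPack capsule, then import it.
    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

    # JAX -> Torch: the reverse direction, as covered by testJaxToTorch below.
    y = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jnp.zeros((2, 3))))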
tests/BUILD

@@ -623,6 +623,13 @@ jax_test(
     ],
 )
 
+jax_test(
+    name = "pytorch_interoperability_test",
+    srcs = ["pytorch_interoperability_test.py"],
+    disable_backends = ["tpu"],
+    deps = py_deps("torch"),
+)
+
 jax_test(
     name = "qdwh_test",
     srcs = ["qdwh_test.py"],
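The py_deps("torch") dependency keeps PyTorch optional at the Bazel level; inside the test module the import is additionally guarded, so the suite still loads when torch is absent. A minimal sketch of that guard pattern, mirroring the code in the new file further below (InteropSmokeTest is a hypothetical example, not part of the commit):

    import unittest

    try:
      import torch  # optional; supplied via py_deps("torch") under Bazel
      import torch.utils.dlpack
    except ImportError:
      torch = None

    class InteropSmokeTest(unittest.TestCase):
      @unittest.skipIf(not torch, "Test requires PyTorch")
      def test_torch_present(self):
        self.assertEqual(tuple(torch.ones(2, 3).shape), (2, 3))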
tests/dlpack_test.py

@@ -19,7 +19,6 @@ from absl.testing import absltest
 import jax
 from jax.config import config
 import jax.dlpack
-from jax._src.lib import xla_bridge, xla_client
 import jax.numpy as jnp
 from jax._src import test_util as jtu
 
@@ -29,12 +28,6 @@ numpy_version = jtu.numpy_version()
 
 config.parse_flags_with_absl()
 
-try:
-  import torch
-  import torch.utils.dlpack
-except ImportError:
-  torch = None
-
 try:
   import cupy
 except ImportError:
@@ -50,8 +43,10 @@ except:
 
 dlpack_dtypes = sorted(list(jax.dlpack.SUPPORTED_DTYPES),
                        key=lambda x: x.__name__)
-torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
-                jnp.uint8, jnp.float16, jnp.float32, jnp.float64]
+
+numpy_dtypes = sorted(
+    [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
+    key=lambda x: x.__name__)
 
 nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
 empty_array_shapes = []
@@ -145,60 +140,7 @@ class DLPackTest(jtu.JaxTestCase):
 
-  @jtu.sample_product(
-    shape=all_shapes,
-    dtype=torch_dtypes,
-  )
-  @unittest.skipIf(not torch, "Test requires PyTorch")
-  def testTorchToJax(self, shape, dtype):
-    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
-      self.skipTest("x64 types are disabled by jax_enable_x64")
-    rng = jtu.rand_default(self.rng())
-    np = rng(shape, dtype)
-    x = torch.from_numpy(np)
-    x = x.cuda() if jtu.device_under_test() == "gpu" else x
-    dlpack = torch.utils.dlpack.to_dlpack(x)
-    y = jax.dlpack.from_dlpack(dlpack)
-    self.assertAllClose(np, y)
-
-  @unittest.skipIf(not torch, "Test requires PyTorch")
-  def testTorchToJaxFailure(self):
-    x = torch.arange(6).reshape((2, 3))
-    y = torch.utils.dlpack.to_dlpack(x[:, :2])
-
-    backend = xla_bridge.get_backend()
-    client = getattr(backend, "client", backend)
-
-    regex_str = (r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
-                 r'striding are supported')
-    with self.assertRaisesRegex(RuntimeError, regex_str):
-      xla_client._xla.dlpack_managed_tensor_to_buffer(
-          y, client)
-
-  @jtu.sample_product(
-    shape=all_shapes,
-    dtype=torch_dtypes,
-  )
-  @unittest.skipIf(not torch, "Test requires PyTorch")
-  def testJaxToTorch(self, shape, dtype):
-    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
-      self.skipTest("x64 types are disabled by jax_enable_x64")
-    rng = jtu.rand_default(self.rng())
-    np = rng(shape, dtype)
-    x = jnp.array(np)
-    dlpack = jax.dlpack.to_dlpack(x)
-    y = torch.utils.dlpack.from_dlpack(dlpack)
-    self.assertAllClose(np, y.cpu().numpy())
-
-  @unittest.skipIf(not torch, "Test requires PyTorch")
-  def testTorchToJaxInt64(self):
-    # See https://github.com/google/jax/issues/11895
-    x = jax.dlpack.from_dlpack(
-        torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
-    dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
-    self.assertEqual(x.dtype, dtype_expected)
-
   @jtu.sample_product(
     shape=all_shapes,
-    dtype=torch_dtypes,
+    dtype=numpy_dtypes,
   )
   @unittest.skipIf(numpy_version < (1, 22, 0), "Requires numpy 1.22 or newer")
   def testNumpyToJax(self, shape, dtype):
@@ -209,7 +151,7 @@ class DLPackTest(jtu.JaxTestCase):
 
   @jtu.sample_product(
     shape=all_shapes,
-    dtype=torch_dtypes,
+    dtype=numpy_dtypes,
   )
   @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
   @jtu.skip_on_devices("gpu") #NumPy only accepts cpu DLPacks
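For the numpy-dtype variants above, the tests lean on NumPy's native DLPack support (numpy.from_dlpack, available from NumPy 1.22), which only accepts CPU-resident producers; that is what the skip_on_devices("gpu") guard reflects. A hedged sketch of that direction, assuming the JAX array is CPU-backed and exposes the __dlpack__ protocol as these tests rely on:

    import numpy as np
    import jax.numpy as jnp

    x = jnp.arange(6.0)      # assumes a CPU-backed JAX array
    a = np.from_dlpack(x)    # requires NumPy >= 1.22; rejects GPU buffers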
tests/pytorch_interoperability_test.py (new file, +97 lines)
@@ -0,0 +1,97 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from absl.testing import absltest
+
+import jax
+from jax.config import config
+import jax.dlpack
+from jax._src.lib import xla_bridge, xla_client
+import jax.numpy as jnp
+from jax._src import test_util as jtu
+
+config.parse_flags_with_absl()
+
+try:
+  import torch
+  import torch.utils.dlpack
+except ImportError:
+  torch = None
+
+
+torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
+                jnp.uint8, jnp.float16, jnp.float32, jnp.float64,
+                jnp.bfloat16, jnp.complex64, jnp.complex128]
+
+nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
+empty_array_shapes = []
+empty_array_shapes += [(0,), (0, 4), (3, 0),]
+nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]
+
+nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
+all_shapes = nonempty_array_shapes + empty_array_shapes
+
+class DLPackTest(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if jtu.device_under_test() == "tpu":
+      self.skipTest("DLPack not supported on TPU")
+
+  @unittest.skipIf(not torch, "Test requires PyTorch")
+  def testTorchToJaxFailure(self):
+    x = torch.arange(6).reshape((2, 3))
+    x = x.cuda() if jtu.device_under_test() == "gpu" else x
+    y = torch.utils.dlpack.to_dlpack(x[:, :2])
+
+    backend = xla_bridge.get_backend()
+    client = getattr(backend, "client", backend)
+
+    regex_str = (r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
+                 r'striding are supported')
+    with self.assertRaisesRegex(RuntimeError, regex_str):
+      xla_client._xla.dlpack_managed_tensor_to_buffer(
+          y, client)
+
+  @jtu.sample_product(
+    shape=all_shapes,
+    dtype=torch_dtypes,
+  )
+  @unittest.skipIf(not torch, "Test requires PyTorch")
+  def testJaxToTorch(self, shape, dtype):
+    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
+      self.skipTest("x64 types are disabled by jax_enable_x64")
+    rng = jtu.rand_default(self.rng())
+    np = rng(shape, dtype)
+    x = jnp.array(np)
+    dlpack = jax.dlpack.to_dlpack(x)
+    y = torch.utils.dlpack.from_dlpack(dlpack)
+    if dtype == jnp.bfloat16:
+      # .numpy() doesn't work on Torch bfloat16 tensors.
+      self.assertAllClose(np,
+                          y.cpu().view(torch.int16).numpy().view(jnp.bfloat16))
+    else:
+      self.assertAllClose(np, y.cpu().numpy())
+
+  @unittest.skipIf(not torch, "Test requires PyTorch")
+  def testTorchToJaxInt64(self):
+    # See https://github.com/google/jax/issues/11895
+    x = jax.dlpack.from_dlpack(
+        torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
+    dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
+    self.assertEqual(x.dtype, dtype_expected)
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
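One detail worth flagging in the new file: torch.Tensor.numpy() does not support bfloat16, so testJaxToTorch compares bfloat16 results by reinterpreting the raw bits. A standalone sketch of that trick (jnp.bfloat16 resolves to a NumPy-compatible dtype, so ndarray.view can reinterpret the int16 bit pattern):

    import jax.numpy as jnp
    import torch

    t = torch.ones(4, dtype=torch.bfloat16)
    # Round-trip through the bit-identical int16 view, then reinterpret the
    # resulting NumPy buffer as bfloat16.
    a = t.view(torch.int16).numpy().view(jnp.bfloat16)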