mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[XLA:Python] Fix __cuda_array_interface__.
Adds a test for __cuda_array_interface__ that does not depend on cupy. Fixes https://github.com/google/jax/issues/16440 PiperOrigin-RevId: 541965361
This commit is contained in:
parent
afcd1a7c8e
commit
0ec03dbdce
@ -37,6 +37,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jaxlib 0.4.13
|
||||
|
||||
* Bug fixes
|
||||
* `__cuda_array_interface__` was broken in previous jaxlib versions and is now
|
||||
fixed ({jax-issue}`16440`).
|
||||
|
||||
## jax 0.4.12 (June 8, 2023)
|
||||
|
||||
* Changes
|
||||
|
@ -21,6 +21,7 @@ from jax import config
|
||||
import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -48,6 +49,8 @@ numpy_dtypes = sorted(
|
||||
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
|
||||
key=lambda x: x.__name__)
|
||||
|
||||
cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
|
||||
empty_array_shapes = []
|
||||
empty_array_shapes += [(0,), (0, 4), (3, 0),]
|
||||
@ -162,6 +165,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(x_np, x_jax)
|
||||
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13")
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -171,12 +175,29 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=dlpack_dtypes,
|
||||
dtype=cuda_array_interface_dtypes,
|
||||
)
|
||||
def testCudaArrayInterfaceWorks(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
y = jnp.array(x)
|
||||
a = y.__cuda_array_interface__
|
||||
self.assertEqual(shape, a["shape"])
|
||||
self.assertEqual(x.__array_interface__["typestr"], a["typestr"])
|
||||
|
||||
def testCudaArrayInterfaceBfloat16Fails(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng((2, 2), jnp.bfloat16)
|
||||
y = jnp.array(x)
|
||||
with self.assertRaisesRegex(RuntimeError, ".*not supported for bfloat16.*"):
|
||||
_ = y.__cuda_array_interface__
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=cuda_array_interface_dtypes,
|
||||
)
|
||||
@unittest.skipIf(not cupy, "Test requires CuPy")
|
||||
def testJaxToCuPy(self, shape, dtype):
|
||||
if dtype == jnp.bfloat16:
|
||||
raise unittest.SkipTest("cupy does not support bfloat16")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
y = jnp.array(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user