mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[jax] Make from_dlpack error on non-default layouts.
We're currently working on better layout support in jax, but for now, jax can't really handle Arrays with non-standard layouts. Raise an error instead of returning wrong results when the Array is passed to a computation (!). This can be reverted once jax has better support for non-default layouts. Related issues: https://github.com/google/jax/issues/7657 https://github.com/google/jax/issues/17784 PiperOrigin-RevId: 589258555
This commit is contained in:
parent
5fb8ceca73
commit
013a01ccae
@ -21,6 +21,7 @@ import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -104,7 +105,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(z.devices(), {device})
|
||||
self.assertAllClose(np.astype(x.dtype), z)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=dlpack_dtypes,
|
||||
@ -181,6 +181,17 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
x_np = np.from_dlpack(x_jax)
|
||||
self.assertAllClose(x_np, x_jax)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 221, "Requires newer jaxlib")
|
||||
def testNondefaultLayout(self):
|
||||
# Generate numpy array with nonstandard layout
|
||||
a = np.arange(4).reshape(2, 2)
|
||||
b = a.T
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"from_dlpack got array with non-default layout with minor-to-major "
|
||||
r"dimensions \(0,1\), expected \(1,0\)"):
|
||||
b_jax = jax.dlpack.from_dlpack(b.__dlpack__())
|
||||
|
||||
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user