[jax] Make from_dlpack error on non-default layouts.

We're currently working on better layout support in jax, but for now,
jax can't really handle Arrays with non-standard layouts. Raise an
error instead of returning wrong results when the Array is passed to a
computation (!). This can be reverted once jax has better support for
non-default layouts.

Related issues:
https://github.com/google/jax/issues/7657
https://github.com/google/jax/issues/17784

PiperOrigin-RevId: 589258555
This commit is contained in:
Skye Wanderman-Milne 2023-12-08 15:32:31 -08:00 committed by jax authors
parent 5fb8ceca73
commit 013a01ccae

View File

@ -21,6 +21,7 @@ import jax.dlpack
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np
@ -104,7 +105,6 @@ class DLPackTest(jtu.JaxTestCase):
self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
@ -181,6 +181,17 @@ class DLPackTest(jtu.JaxTestCase):
x_np = np.from_dlpack(x_jax)
self.assertAllClose(x_np, x_jax)
@unittest.skipIf(xla_extension_version < 221, "Requires newer jaxlib")
def testNondefaultLayout(self):
# Generate numpy array with nonstandard layout
a = np.arange(4).reshape(2, 2)
b = a.T
with self.assertRaisesRegex(
RuntimeError,
r"from_dlpack got array with non-default layout with minor-to-major "
r"dimensions \(0,1\), expected \(1,0\)"):
b_jax = jax.dlpack.from_dlpack(b.__dlpack__())
class CudaArrayInterfaceTest(jtu.JaxTestCase):