diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b9d4cd49..1c9733ae4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2021,6 +2021,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, if isinstance(object, (bool, int, float, complex)): _ = dtypes.coerce_to_array(object, dtype) + if hasattr(object, '__jax_array__'): + object = object.__jax_array__() object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf, object) leaves = tree_leaves(object) diff --git a/tests/api_test.py b/tests/api_test.py index f35aab5d5..47b84768d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3898,6 +3898,8 @@ class APITest(jtu.JaxTestCase): shape = property(operator.attrgetter('x.shape')) a = A(jnp.ones((3, 3))) + jnp.asarray(a) # don't crash + f = jax.jit(jnp.matmul) f(a, a) # don't crash