Merge pull request #16769 from jakevdp:fix-jax-array

PiperOrigin-RevId: 548835682
This commit is contained in:
jax authors 2023-07-17 16:51:13 -07:00
commit 8016fb3b66
2 changed files with 4 additions and 0 deletions

View File

@ -2021,6 +2021,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)
if hasattr(object, '__jax_array__'):
object = object.__jax_array__()
object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
object)
leaves = tree_leaves(object)

View File

@ -3898,6 +3898,8 @@ class APITest(jtu.JaxTestCase):
shape = property(operator.attrgetter('x.shape'))
a = A(jnp.ones((3, 3)))
jnp.asarray(a) # don't crash
f = jax.jit(jnp.matmul)
f(a, a) # don't crash