support np.array(x) where x is a custom pytree with __jax_array__

This commit is contained in:
Jake VanderPlas 2023-07-17 13:33:17 -07:00
parent 68ea651ae4
commit 74159132b6
2 changed files with 4 additions and 0 deletions

View File

@ -2021,6 +2021,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)
if hasattr(object, '__jax_array__'):
object = object.__jax_array__()
object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
object)
leaves = tree_leaves(object)

View File

@ -3898,6 +3898,8 @@ class APITest(jtu.JaxTestCase):
shape = property(operator.attrgetter('x.shape'))
a = A(jnp.ones((3, 3)))
jnp.asarray(a) # don't crash
f = jax.jit(jnp.matmul)
f(a, a) # don't crash