mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #16769 from jakevdp:fix-jax-array
PiperOrigin-RevId: 548835682
This commit is contained in:
commit
8016fb3b66
@ -2021,6 +2021,8 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
|
||||
if isinstance(object, (bool, int, float, complex)):
|
||||
_ = dtypes.coerce_to_array(object, dtype)
|
||||
|
||||
if hasattr(object, '__jax_array__'):
|
||||
object = object.__jax_array__()
|
||||
object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
|
||||
object)
|
||||
leaves = tree_leaves(object)
|
||||
|
@ -3898,6 +3898,8 @@ class APITest(jtu.JaxTestCase):
|
||||
shape = property(operator.attrgetter('x.shape'))
|
||||
|
||||
a = A(jnp.ones((3, 3)))
|
||||
jnp.asarray(a) # don't crash
|
||||
|
||||
f = jax.jit(jnp.matmul)
|
||||
f(a, a) # don't crash
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user