rocm_jax/jax/experimental
Yash Katariya e32373c3ea
Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also add tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00