rocm_jax/jax/experimental
Peter Hawkins 808003b4e2 Update users of jax.tree.map() to be more careful about how they handle Nones.
Because of a bug, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as a pytree prefix of any value. But `None` is a pytree container with no children, so it is a prefix only of `None` itself.
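A small sketch of what "`None` is a pytree container" means in practice. By default `None` contributes no leaves. Passing an `is_leaf` predicate turns it into an ordinary leaf. The tree `[1, None, 2]` is purely illustrative.

```python
import jax

tree = [1, None, 2]

# By default None is an empty container, so it contributes no leaves.
default_leaves = jax.tree.leaves(tree)  # [1, 2]

# With an is_leaf predicate, None is kept as an ordinary leaf instead.
none_as_leaf = jax.tree.leaves(tree, is_leaf=lambda t: t is None)  # [1, None, 2]

# The structure of None on its own has zero leaves.
none_leaf_count = jax.tree.structure(None).num_leaves  # 0
```

Because `None` has no leaves, the only tree whose structure it matches is `None` itself. That is why it cannot serve as a prefix of an arbitrary value.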

Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.
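The following is a minimal, self-contained sketch of that fix pattern. It assumes a hypothetical per-leaf function `f` and illustrative trees `x` and `y`, where `x` holds `None` in a slot where `y` holds an array. On JAX versions that include this fix, calling `jax.tree.map(f, x, y)` directly on these trees raises an error, because `x["b"]` is the empty container `None` while `y["b"]` is an array.

```python
import jax
import jax.numpy as jnp


def f(a, b):
    # Hypothetical per-leaf function, used only for illustration.
    return a + b


# Illustrative trees: x has None where y has an array.
x = {"w": jnp.ones(3), "b": None}
y = {"w": jnp.full(3, 2.0), "b": jnp.zeros(2)}

# is_leaf makes None a leaf of x, so it lines up with whatever y holds
# in that slot; the lambda then passes None through instead of calling f.
out = jax.tree.map(
    lambda a, b: None if a is None else f(a, b),
    x,
    y,
    is_leaf=lambda t: t is None,
)
# out["w"] is [3., 3., 3.] and out["b"] is None.
```

Treating `None` as a leaf only in the first tree is what makes this work: the remaining trees are matched against that structure, so `y["b"]` is handed to the lambda whole rather than being required to be `None`.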

PiperOrigin-RevId: 673258116
2024-09-10 23:54:11 -07:00