Fix this by calculating the donation vector from the in_tree.
A bonus is that we can now cache the donation-vector calculation, leading to faster tracing times in JAX.
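A minimal sketch of the idea, assuming a hypothetical helper name; the real implementation lives in JAX internals and derives leaf counts from the in_tree:
```
import jax

# Hypothetical helper, illustrative only: expand donate_argnums into a
# per-leaf boolean donation vector using each argument's pytree leaves.
def donation_vector(donate_argnums, args):
    donated = []
    for i, arg in enumerate(args):
        n_leaves = len(jax.tree_util.tree_leaves(arg))
        donated.extend([i in donate_argnums] * n_leaves)
    return tuple(donated)

print(donation_vector((0,), ({"w": 1.0, "b": 2.0}, 3.0)))
# (True, True, False)
```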
PiperOrigin-RevId: 627512710
The default `fn.__name__` was added in `_one_to_one_unop` but not in other functions, so many downstream function wrappers report uninformative names while debugging. For instance, when a JAX primitive such as `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` gives:
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: fn
```
instead of
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: add
```
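The underlying Python pattern is the usual name propagation through wrappers, sketched here in isolation (not the actual `lu.WrappedFun` change):
```
import functools

# functools.wraps copies fn.__name__ (among other attributes) onto the
# wrapper, so debug output can report "add" instead of a placeholder
# like "fn".
def make_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

def add(x, y):
    return x + y

print(make_wrapper(add).__name__)  # add
```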
PiperOrigin-RevId: 627417452
- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.
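For context, a minimal `shard_map` sketch where the mapped mesh axis becomes a manual axis in the lowered program (mesh shape and axis names are illustrative):
```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-axis mesh over all available devices; the axis name "x" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("x",))

# Inside the mapped function, axis "x" is manual: each program instance
# sees only its local shard of the input.
f = shard_map(lambda a: a * 2, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
print(f(jnp.arange(8.0)))
```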
PiperOrigin-RevId: 627156892
* Cache the sharding index comparison in addition to the sharding index calculation. This helps when the list of indices is expensive to compare.
* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.
* Use `a is b` instead of `id(a) == id(b)` since the former is more concise.
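A sketch of the caching pattern with illustrative helper names (only `Sharding.devices_indices_map` is public JAX API):
```
import functools

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative, not JAX internals: cache the expensive index computation,
# then make the comparison cheap by checking identity first.
@functools.lru_cache(maxsize=1024)
def shard_indices(sharding, shape):
    return tuple(sharding.devices_indices_map(shape).items())

def indices_match(s1, s2, shape):
    a = shard_indices(s1, shape)
    b = shard_indices(s2, shape)
    # `a is b` hits when both lookups returned the same cached tuple; it is
    # the concise spelling of id(a) == id(b).
    return a is b or a == b

mesh = Mesh(np.array(jax.devices()), ("x",))
s = NamedSharding(mesh, P("x"))
print(indices_match(s, s, (8,)))  # True, via the identity fast path
```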
PiperOrigin-RevId: 627080325
Prior to this change, the behavior in eager mode and under jax.jit was inconsistent:
```
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
```
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
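With the change, both paths cast the callback's arguments to jax.Arrays. A sketch of the now-consistent behavior (exact reprs may vary across versions):
```
import jax

f = lambda *args: jax.debug.callback(print, *args)
f([42])            # the list's leaf now arrives as a jax.Array in eager mode
jax.jit(f)([42])   # and under jit, matching the other callback APIs
```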
Closes #20627.
PiperOrigin-RevId: 626461904
This test started to fail in compiled mode for complex128, but only with small errors (~1e-14).
Adjust the tolerance for both the native and non-native serialization modes.
PiperOrigin-RevId: 626373781
The existing lowering path supports only `while_loop`s that can be converted to `fori_loop`.
That path makes the loops significantly easier to optimize and unroll, but it cannot support a large class of interesting loop formulations.
This patch draws on the Pallas -> Triton `while_loop` lowering rule to support such loops in Pallas.
Matching against `fori_loop` is still performed first, so we lower via that mechanism when possible, as it is likely more straightforward to optimize than a general `while`.
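To illustrate the distinction in plain JAX (not a Pallas kernel): the first loop below has a fixed trip count and can be rewritten as a `fori_loop`, while the second's exit condition depends on carried data and needs the general `while` lowering:
```
import jax

# Fixed trip count: convertible to fori_loop.
def fixed_count(x):
    def body(carry):
        i, v = carry
        return i + 1, v * 2.0
    return jax.lax.while_loop(lambda c: c[0] < 10, body, (0, x))[1]

# Data-dependent exit: requires a general while lowering.
def data_dependent(x):
    return jax.lax.while_loop(lambda v: v < 100.0, lambda v: v * 2.0, x)

print(jax.jit(fixed_count)(1.0))     # 1024.0
print(jax.jit(data_dependent)(1.0))  # 128.0
```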
PiperOrigin-RevId: 626089180