- Fix non determinism in axis name mapping in Pallas lowering.
- Implement mesh info support in Pallas transform jaxprs so blockspecs can access mesh info.
- Fix non compatibility of mesh info and scalar prefetching in Pallas lowering.
- Fix Pallas lowering of multi-dimensional topology remote DMAs.
PiperOrigin-RevId: 580357651
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.
Testing: revised unit test.
PiperOrigin-RevId: 579951415
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.
Testing: test workloads.
PiperOrigin-RevId: 579928158
XLA now calls ducc itself as of da67903a4c, so we don't need a custom call in JAX any more. In addition, the DUCC call from XLA receives a thread pool and is parallelized.
Fixes https://github.com/google/jax/issues/14664
PiperOrigin-RevId: 579829580
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.
Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.
See more details in the docstring of lax.platform_dependent.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.
PiperOrigin-RevId: 578692796
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.