The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.
Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.
See more details in the docstring of lax.platform_dependent.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.
PiperOrigin-RevId: 578692796
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
The `config` object it re-exports is available on the top-level `jax` package,
i.e.
from jax.config import config
can always safely be replaced with
from jax import config
or just
import jax
jax.config
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.
PiperOrigin-RevId: 576575376
Those failures should be treated as internal and reported as bugs.
As a next step, we should add a verifier pass that explicitly checks
that all the ops in the module are supported.
PiperOrigin-RevId: 576502802
call_tf has per-platform lowering because the lowering
of the called TF function may depend on the platform. When
doing multi-platform lowering this means that we lower
call_tf several times and wrap the lowerings with a
conditional. This results in an assertion failure
in add_to_call_tf_concrete_function_list, because we
are attempting to add the same function multiple times.
Here we remove the assertion (afaik, it is Ok to add
multiple functions with the same name, because all
we care about is the index of the called function in
the list). We also reuse the existing function if
we are adding an identical one.
We add tests for call_tf with multi-platform lowering.