Ever since the jit-pjit merge, the "Python" jit test has actually just called the same code as the "C++" jit test. We don't have a C++-free jit path any more. Remove the "Python" tests since they don't test anything.
PiperOrigin-RevId: 581965049
The compilation_cache_test had an exclusion since the C PjRt
topology description had not been implemented. Now that it is
available, remove the exclusion.
PiperOrigin-RevId: 581396824
The current implementation synchronously calls `ArrayImpl.block_until_ready()` on each array, one by one. This is suboptimal when querying the readiness of an array is not cheap. Also, calling `x.block_until_ready()` for every array causes the GIL to be acquired and released repeatedly.
To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.
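The "collect futures, wait once" idea, together with the slow path, can be illustrated with a minimal pure-Python sketch. It does not use JAX or IFRT: `FakeArray`, `LegacyObject`, and `block_until_ready_all` are hypothetical names, and `concurrent.futures.Future` only stands in for the ready future returned by `Array::GetReadyFuture()`.

```python
from concurrent.futures import Future, wait


class FakeArray:
    """Hypothetical stand-in for an array that exposes a readiness future."""

    def __init__(self):
        self._ready = Future()

    def mark_ready(self):
        self._ready.set_result(None)

    def ready_future(self):
        return self._ready


class LegacyObject:
    """Hypothetical non-array object that only implements block_until_ready."""

    def __init__(self):
        self.blocked = False

    def block_until_ready(self):
        self.blocked = True
        return self


def block_until_ready_all(leaves):
    """Collect readiness futures for all arrays, then wait on them once."""
    futures = []
    for leaf in leaves:
        if isinstance(leaf, FakeArray):
            # Fast path: only record the future, do not block per array.
            futures.append(leaf.ready_future())
        elif hasattr(leaf, "block_until_ready"):
            # Slow path: preserve the old per-object behavior.
            leaf.block_until_ready()
    wait(futures)  # a single wait covering every array
    return leaves
```

Leaves that are neither arrays nor objects with `block_until_ready` (plain Python scalars, for example) pass through untouched in this sketch.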
PiperOrigin-RevId: 581302290
Remove the check that the min compile time is greater than zero. After this change, cache_misses are also caught when the min compile time is zero.
Testing: revised unit test.
PiperOrigin-RevId: 579951415
This is a big step toward enabling xla_gpu_triton_gemm_any by default.
It shows about 1.05x geomean speedup on internal benchmarks (comparable to xla_gpu_triton_gemm_any=true).
PiperOrigin-RevId: 579524573
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.
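The fallback can be sketched in plain Python as follows. This is not the JAX implementation: `cache_key`, `serialize_topology`, `TopologyBackend`, and `LegacyBackend` are hypothetical names, and detecting support via a missing method is an assumption made for illustration.

```python
import hashlib


class TopologyBackend:
    """Hypothetical backend that can serialize its topology description."""
    platform = "tpu"
    devices = ("tpu:0", "tpu:1")

    def serialize_topology(self):
        return b"2x1 tpu topology"


class LegacyBackend:
    """Hypothetical backend without a serialized topology description."""
    platform = "cpu"
    devices = ("cpu:0",)


def cache_key(module_bytes, backend):
    """Hash the program with a topology description, or fall back."""
    hasher = hashlib.sha256()
    hasher.update(module_bytes)
    serialize = getattr(backend, "serialize_topology", None)  # assumed probe
    if serialize is not None:
        # New algorithm: key on the serialized topology description.
        hasher.update(b"topology:" + serialize())
    else:
        # Fallback: key on the backend platform and its devices.
        hasher.update(b"backend:" + backend.platform.encode())
        for device in backend.devices:
            hasher.update(str(device).encode())
    return hasher.hexdigest()
```

Prefixing each branch with a distinct tag keeps keys from the two algorithms from colliding by accident.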
Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.
See more details in the docstring of lax.platform_dependent.
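A minimal usage sketch, assuming the keyword-per-platform form of `lax.platform_dependent` with a `default` fallback. The branch functions `cpu_impl` and `default_impl` are illustrative names; both compute the same value so that the result does not depend on where the computation runs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cpu_impl(x):
    # Selected when the computation is lowered for CPU.
    return x * 2.0


def default_impl(x):
    # Selected on every platform without a specific branch.
    return x + x


@jax.jit
def double(x):
    # The branch is chosen at lowering time, not by a Python conditional.
    return lax.platform_dependent(x, cpu=cpu_impl, default=default_impl)


result = double(jnp.arange(3.0))
```

In contrast, a Python `if` on the current default platform would be evaluated once at trace time and would bake in the wrong branch if the lowered computation later ran elsewhere.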
Transferring an array from host to device on CPU sometimes uses a zero-copy implementation where no memory is actually moved. This is now never done for int4 arrays, since they are stored in a packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it unpacks and copies int4 arrays.
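To see why a zero-copy transfer cannot work when the two layouts differ, here is a small pure-Python sketch of packing and unpacking int4 values. The nibble order (low nibble first) and the helper names `pack_int4` and `unpack_int4` are assumptions for illustration, not XLA's documented layout.

```python
def pack_int4(values):
    """Pack signed 4-bit integers, two per byte (low nibble first)."""
    if len(values) % 2:
        values = list(values) + [0]  # pad to an even count
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(packed)


def unpack_int4(packed, count):
    """Unpack `count` signed 4-bit integers, one per list element."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # Sign-extend the 4-bit two's-complement value.
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out[:count]
```

Because the packed buffer is half the size of the unpacked one, the same memory cannot serve as both representations, so a copy with conversion is needed in each direction.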
PiperOrigin-RevId: 578692796
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
In the presence of ordered effects, JAX lowering produces a main
function that takes token inputs and returns token outputs.
Previously, when exporting such a module, we would wrap the main
function with a function that does not use tokens on inputs and
outputs. With this change we leave the token inputs and outputs
in place and rely on consumers of the exported function to know
how to invoke a function with tokens.
Because PJRT does not support passing tokens as inputs and
outputs of the top-level function, JAX native lowering uses
dummy bool[0] arrays in lieu of tokens for the top-level
function, and uses stablehlo tokens for the inner functions.
When we export a function for serialization we want to use
stablehlo tokens even at the top level, to enable calling that
function from a larger JAX computation later.
See more details about the calling convention in the
docstring for `export.export`.
We also fix and test multi-platform lowering in the presence
of effects.
This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.