We recently increased test coverage for dot_general with different dtypes for lhs and rhs. Some of the new dtype combinations are not supported by XLA:GPU, so we now disable those tests.
PiperOrigin-RevId: 555465495
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.
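The fall-back order above can be sketched roughly as follows. This is a
minimal illustration, not the actual JAX code: `select_profile_version`, its
parameters, and the `NO_PROFILE` constant are hypothetical names. Only the
fall-back order and the -1 sentinel are taken from the description.

```python
from typing import Callable, Optional

# Sentinel from the description: do not try to retrieve a profile later on.
NO_PROFILE = -1


def select_profile_version(
    flag_version: int,
    get_latest_profile_version: Optional[Callable[[], int]] = None,
) -> int:
    """Hypothetical helper mirroring the fall-back order described above."""
    # 1. An explicit non-zero flag value wins.
    if flag_version != 0:
        return flag_version
    # 2. Otherwise ask the registered function, if one is set.
    if get_latest_profile_version is not None:
        latest = get_latest_profile_version()
        if latest != 0:
            return latest
    # 3. Nothing usable: signal that no profile should be fetched.
    return NO_PROFILE
```

With these assumptions, a non-zero flag value always takes precedence over the
function, and a missing function behaves the same as one that returns 0.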
Testing: updated unit tests.
PiperOrigin-RevId: 555333728
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.
This change sets up the plumbing. Later changes will implement the new
algorithm.
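While both algorithms coexist, the dispatch could look something like the
sketch below. All names here (`get_cache_key`, `_original_key`, `_new_key`,
and the `use_original` parameter standing in for the flag value) are
hypothetical, and the hashing bodies are placeholders rather than either real
algorithm.

```python
import hashlib


def _original_key(module_text: str) -> str:
    # Placeholder standing in for the original key-generation algorithm.
    return "orig-" + hashlib.sha256(module_text.encode()).hexdigest()


def _new_key(module_text: str) -> str:
    # Placeholder standing in for the new algorithm (not yet implemented).
    return "new-" + hashlib.sha256(module_text.encode()).hexdigest()


def get_cache_key(module_text: str, use_original: bool) -> str:
    """Picks one algorithm based on the flag value (hypothetical signature)."""
    if use_original:
        return _original_key(module_text)
    return _new_key(module_text)
```

Keeping the choice in a single function means that removing the flag later
only requires deleting one branch and the original helper.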
Testing: test workload.
PiperOrigin-RevId: 555333628
This change adds `xla_client.DeviceList`, which is implemented in C++ as
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement.
`jax::PyDeviceList` internally holds an `xla::ifrt::DeviceList`, which will be
used with IFRT APIs without having to construct a new copy of a potentially
large device list.
`pxla._DeviceAssignment`'s interface is changed slightly to discourage
conversion to a tuple.
Note that for backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still uses a Python tuple whose elements can be
any Python objects that match the `xla_client.Device` interface via duck typing.
This duck typing support will be removed once that use case is deprecated.
Eventually, we can try to avoid any type conversion to remove a shadow copy of
the device list in JAX.
PiperOrigin-RevId: 555317152
* https://reviews.llvm.org/D155209 added support to the MLIR Python bindings for passing types like bfloat16 directly if an explicit IR type is provided.
* The crash for non-splat size 1 constants appears to be fixed at head, although I don't know which change fixed it.
PiperOrigin-RevId: 555225604
Add private functions _unregister_event_duration_listener_by_callback and _unregister_event_duration_listener_by_index to remove registered event duration listeners. These functions are intended to be called from tests only.
PiperOrigin-RevId: 555208764
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public) is on the list, so the allowlist no longer serves any purpose.
PiperOrigin-RevId: 555124144
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.
Testing: test workload.
PiperOrigin-RevId: 554942573
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.
PiperOrigin-RevId: 554608165
NOTE: this version of DUCC has a breaking change: the fft.h header no longer
contains the definitions of many FFT functions. Instead, they are defined in
fft1d_impl.h and fftnd_impl.h.
PiperOrigin-RevId: 554567641