Create metrics:
1) '/jax/compilation_cache/cache_retrieval_time_sec' to record the time duration for getting cache entries.
2) '/jax/compilation_cache/original_compile_time_saved_sec' to record the time saved on cache hits.
PiperOrigin-RevId: 556243588
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.
See the CHANGELOG for details.
PiperOrigin-RevId: 556222908
It appears that some users pass lists as axis arguments, and these are allowed by the types on the regular `jax.numpy` functions. Relax the type annotations on the methods to match the free functions.
PiperOrigin-RevId: 556085084
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.
PiperOrigin-RevId: 555955257
We recently increased the test coverage of testing for dot_general with different dtype for lhs and rhs. Some of the new combinations of dtypes are not supported by XLA:GPU, and we disable those tests now.
PiperOrigin-RevId: 555465495
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.
Testing: updated unit tests.
PiperOrigin-RevId: 555333728
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.
This change sets up the plumbing. Later changes will implement the new
algorithm.
Testing: test workload.
PiperOrigin-RevId: 555333628
This change adds `xla_client.DeviceList` that is implemented in C++
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement.
`jax::PyDeviceList` internally has `xla::ifrt::DeviceList`, which will be used
when using IFRT APIs without having to construct a new copy of a potentially
large device list.
`pxla._DeviceAssignment`'s interface is changed slightly to encourage avoiding
conversion to tuple.
Note that for the backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still uses a Python tuple whose element can be
any Python object matches `xla_client.Device` interface with duck typing. This
duck typing support will be removed when such use case is deprecated.
Eventually, we can try to avoid any type conversion to remove a shadow copy of
device list in JAX.
PiperOrigin-RevId: 555317152
* https://reviews.llvm.org/D155209 added support to the MLIR Python bindings for passing types like bfloat16 directly if an explicit IR type is provided.
* the crash for non-splat size 1 constants appears fixed at head, although I don't know which change fixed it.
PiperOrigin-RevId: 555225604