With the advent of heterogenuous compute, XLA compilation now encompasses sub-compilation for multiple devices. These all can use LLVM, but with different settings. Today this means it is possible for one XLA client to reinitialize LLVM's global state while another client is in the middle of compilation.
Add a global lock around our LLVM usage. Concurrent compilation is still allowed, as long as both invocations have the same set of options. This means from within the same client multiple compilation invocations should still be non-blocking.
PiperOrigin-RevId: 695981613
This commit introduces new CI scripts and environment files for running Bazel CPU presubmits.
* Adds a ci directory at the root of the repository to store these files.
* Environment files are located in ci/envs and define new JAXCI_ environment variables to control CI build behavior.
* The build script sources these environment files and set up the build environment before running the build commands.
PiperOrigin-RevId: 695957540
Fixes a surprising interaction between the generator system in linear_util.py
and the try/finally python context managers we use for managing tracing context.
The `finally` block wasn't always being called until garbage collection, so the
context stack pushes/pops weren't always correctly nested. Dedenting the yield
fixes this particular bug but long-term we should get rid of linear_util
altogether.
PiperOrigin-RevId: 695898528
This CLs adds parameter names to the optional parameters of `tpu.sem_signal` -- `device_id`, `core_id` -- to remove the ambiguity upon deserialization.
Adds LIT tests of signalling on TC with parameter names.
PiperOrigin-RevId: 695875037
Unaligned concat used to be f32 only, but implicitly protected via unimplemented support for multi-row-shift in sub32 types. When this was added, we started invoking unaligned concat flow w/ sub32 types, but the masking code that assumed full rows (unpacked types) was no longer sufficient - we need better granularity for these cases. This only affects sublanes, as that is where we pack, we don't have partial lanes.
This CL, as a small benefit, also adds better error messages to the ops involved in lower_to_llo.cc.
PiperOrigin-RevId: 695796095
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.
PiperOrigin-RevId: 695694648
`xla::HostCallbackArgInfo` uses `uint16_t` for channel ids, so we should warn users explicitly when the channel ids exceed the UINT16_MAX instead of silently wrapping around.
PiperOrigin-RevId: 695682871
Previously, `jvp(lax.sort)` used a shape-dependent dtype, for
the types of indices (either `int32` or `int64`, depending on
the size of the dimension). For shape polymorphism, input shapes
can affect other intermediate shapes, but not `dtype`s.
In this case it is easy to just use `int46` independent of
the actual shape.
In some cases, `compilation_cache.is_cache_used` can reach the end of the function body without returning anything. This amounts to an implicit `return None`, which is not in line with the functions return type of `bool`. We fix this by adding a final `return False` to the function.
* Generalize any untiled memref to have tiling (packing, 128)
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.
PiperOrigin-RevId: 695516124