Turns out that waiting for the kernel to finish it not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.
PiperOrigin-RevId: 690624999
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.
PiperOrigin-RevId: 689416587
Originally proposed in #24021. Slightly rewritter to make testing with internal LLVM toolchains better.
Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.
Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.
This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:
1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.
2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.
Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.
PiperOrigin-RevId: 687106965
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.
PiperOrigin-RevId: 687003464
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 685689593
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 685679646
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.
PiperOrigin-RevId: 684116306
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.
PiperOrigin-RevId: 682415506
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.
The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.
PiperOrigin-RevId: 682312752
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing error, we should just reset offset to zero. Because some ops which has relaxed this restriction might be passed as input for un-relaxed ops and cause failure.
PiperOrigin-RevId: 680706301
This fixes the failure in elementwise rule of apply vector layout pass.
If the condition scalar is static, it will be simplified to corresponding vector from true value and false value by MLIR.
If the condition scalar is dynamic, we want to use vselect over scf.if anyway. Because latter creates a inner region.
PiperOrigin-RevId: 680674560
We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build.
The size impact is small (around 400K uncompressed).
PiperOrigin-RevId: 679608092
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".
We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.
PiperOrigin-RevId: 679563155
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286