So far all of our layouts have been tailored to the limited set of use
cases we have encountered, but they are still not general enough to
handle all of the register layouts needed for WGMMA or mixed-precision
matmuls (including intermediate steps during conversions). Instead of adding
more special cases, I decided to adopt XLA tiled layouts, and they seem
to work quite well!
This change only lays the groundwork for the new layout system. Future
changes will build upon it to add new features and eventually replace
`WGMMA_LAYOUT` altogether.
PiperOrigin-RevId: 694105514
Previously we did not fully discharge the squeezing of indexed
dims before applying other GMEM transforms, which could lead to
failures because those transforms did not anticipate the increased rank.
PiperOrigin-RevId: 694098739
This requires that the file providing the bindings have the same name as the
dialect it defines, since the dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.
PiperOrigin-RevId: 693241875
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.
There are two main sources of speedup:
* Python thread-local state is built around a dictionary and isn't terribly fast; a C++ implementation avoids that overhead.
* We can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time effectively eagerly computing cache keys via update_thread_local_jit_state, although most of that work is pointless. Instead, we can have `jit` simply pull the config items it needs on demand.
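For context, here is a minimal sketch (not JAX's actual implementation) of the dictionary-backed thread-local pattern described above; every read goes through a `threading.local` lookup, which is the overhead a C++ implementation avoids. All names are illustrative.
```
# Illustrative only: a config value with a global default and per-thread
# overrides stored in threading.local, so each read is a dict-backed lookup.
import threading

class ThreadLocalConfig:
    def __init__(self, name, default):
        self.name = name
        self.default = default
        self._tls = threading.local()

    def get(self):
        # Per-thread override first, then the global default.
        return getattr(self._tls, "value", self.default)

    def set_local(self, value):
        self._tls.value = value

enable_x64 = ThreadLocalConfig("jax_enable_x64", False)
enable_x64.set_local(True)
assert enable_x64.get() is True
```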
PiperOrigin-RevId: 693114411
If I revert the change in `shard_map.py`, then the newly added unit test `test_partial_auto_propagate_through` fails with:
```
self.assertEqual(actual.sharding, sharding)
AssertionError: Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec(), memory_kind=device) != Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec('i',), memory_kind=device)
```
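For reference, a minimal sketch of the sharding comparison behind that assertion, assuming a machine with at least four devices arranged in a 2x2 mesh over axes 'i' and 'j' (the mesh and axis names match the error message; the rest is illustrative):
```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 2x2 mesh over axes 'i' and 'j', as in the error message above.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=('i', 'j'))

expected = NamedSharding(mesh, PartitionSpec('i'))  # sharded over 'i'
actual = NamedSharding(mesh, PartitionSpec())       # replicated (pre-fix result)

# Without the shard_map.py change the propagated sharding drops the 'i' spec,
# so an equality check like the one in the test fails.
print(actual == expected)  # False before the fix
```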
PiperOrigin-RevId: 692971413
`cuda_nvcc`, when installed e.g. via `pip` in a `venv`, comes out as a
namespace package. The previous logic found the `cuda_nvcc` import but
failed because `cuda_nvcc.__file__` is `None`.
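A hedged sketch of one way to locate such a package that also works for namespace packages (whose `__file__` is `None`); the use of `importlib` here is illustrative, not necessarily the exact fix:
```
import importlib.util

# Resolve the package's filesystem location(s) without relying on __file__.
spec = importlib.util.find_spec("cuda_nvcc")
if spec is None:
    nvcc_paths = []
elif spec.submodule_search_locations:   # set for regular and namespace packages
    nvcc_paths = list(spec.submodule_search_locations)
else:
    nvcc_paths = [spec.origin] if spec.origin else []
```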
With jax.experimental.export gone, we can now do some cleanup in the export module.
In particular, we remove the `export.args_spec` API and the `lowering_platforms` arg of `export.export`; both were deprecated in June 2024.
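For anyone migrating, a hedged sketch of the replacement spelling, assuming the `platforms` keyword of `export.export` (the call below is illustrative):
```
import jax
import jax.numpy as jnp
from jax import export

exported = export.export(
    jax.jit(jnp.sin),
    platforms=("cpu",),  # replaces the removed `lowering_platforms` argument
)(jax.ShapeDtypeStruct((4,), jnp.float32))
print(exported.platforms)
```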
PiperOrigin-RevId: 692398132
This CL deprecates the `device_type` parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation.
The latter is more fine-grained: it is applied to a `func.FuncOp` instead of the entire module and supports `tc`, `sc_scalar_subcore`, and `sc_vector_subcore`.
The `device_type` of the TPU CustomCall HLO is set to `sparsecore` if the `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets the TC.
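A hedged sketch of attaching the per-function annotation, using jaxlib's MLIR Python bindings; the attribute name and values come from the description above, while the module construction here is purely illustrative:
```
from jaxlib.mlir import ir

with ir.Context():
    module = ir.Module.parse("""
      func.func @kernel(%arg0: i32) -> i32 {
        return %arg0 : i32
      }
    """)
    kernel = module.body.operations[0]  # the kernel's func.FuncOp
    # Target the SparseCore vector subcore for this function only.
    kernel.attributes["tpu.core_type"] = ir.StringAttr.get("sc_vector_subcore")
```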
PiperOrigin-RevId: 692212644
This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`.
It also improves the corresponding test cases to ensure better coverage and accuracy.
PiperOrigin-RevId: 691988164
mlir.ir.Context has the unfortunate behavior that it loads all dialects linked into the binary, even those we have no intention of using. This is fairly benign in JAX's usual configuration, but if JAX is linked together with other MLIR-using software it can be problematic.
PiperOrigin-RevId: 691984229
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.