This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.
PiperOrigin-RevId: 734553668
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.
Fixes#26013.
PiperOrigin-RevId: 733293299
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.
This is in preparation of changes implementing transform inference.
PiperOrigin-RevId: 732091266
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.
PiperOrigin-RevId: 731731530
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
PiperOrigin-RevId: 731635736
While the predicate helps us avoid branching, it can be created once per
block. Its creation uses `*.sync` instructions, which are not DCEd by
LLVM and end up polluting the final code.
PiperOrigin-RevId: 731253109
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction. In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.
PiperOrigin-RevId: 731103569
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.
PiperOrigin-RevId: 731008883
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)
Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variables combinations for creating wheels with different versions:
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 730916743
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.
PiperOrigin-RevId: 730824239
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.
PiperOrigin-RevId: 729561359
- Checks bounds for reads and writes to shared memory.
- Pads kernel arguments when necessary.
- Fix support for input-output aliasing.
- Fix handling of vmap'ed dimensions.
- Supports un-masked `pl.load` and masked or un-masked `pl.swap`.
- Switch to using single integer device IDs instead of tuples.
- Better error messages for unsupported primitives: `for_p`, `atomic_rmw_p`, and `atomic_cas_p` .
PiperOrigin-RevId: 727301519