805 Commits

Author SHA1 Message Date
Sharad Vikram
c6b164dc09 [Pallas/Fuser] Add custom evaluate to allow/disallow transposes
PiperOrigin-RevId: 735931978
2025-03-11 16:35:49 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
02505fa757 [Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter
PiperOrigin-RevId: 735564365
2025-03-10 17:19:39 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Sharad Vikram
81dde225b0 [Pallas/Fuser] Add select_n push rule
PiperOrigin-RevId: 735510713
2025-03-10 14:23:01 -07:00
Sharad Vikram
87272fbe93 [Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
PiperOrigin-RevId: 735505460
2025-03-10 14:07:26 -07:00
Jacob Burnim
73d20cd62a [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
d32e282ff9 Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
2025-03-03 17:26:44 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
2a1eeb0ce8 Chnages for kernel export
PiperOrigin-RevId: 732383028
2025-03-01 00:32:39 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using waprgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Sharad Vikram
6f57410e12 [Pallas TPU] Use grid_env for pipeline body so we can query num_programs/program_id inside the block spec
PiperOrigin-RevId: 731831543
2025-02-27 12:53:02 -08:00
jax authors
da39b6f3d4 Comment change
PiperOrigin-RevId: 731792151
2025-02-27 11:07:59 -08:00
Adrian Kuegel
de4d047852 Change int4 packing from big-endian to little-endian
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.

PiperOrigin-RevId: 731731530
2025-02-27 08:13:43 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Sharad Vikram
2646b8d4ad [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
PiperOrigin-RevId: 731543938
2025-02-26 19:34:36 -08:00
Sharad Vikram
1ecbac9702 [Pallas] Add name parameter to core_map
PiperOrigin-RevId: 731536152
2025-02-26 18:59:01 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
3251b55ef2 [Pallas:MGPU] Don't recreate single_thread_predicate at every rule
While the predicate helps us avoid branching, it can be created once per
block. Its creation uses `*.sync` instructions, which are not DCEd by
LLVM and end up polluting the final code.

PiperOrigin-RevId: 731253109
2025-02-26 04:02:21 -08:00
Benjamin Chetioui
7a34f1cedc [Pallas/Mosaic GPU][NFC] Move thread_semantics to ModuleContext.
This simplifies the propagation of the argument, and is the proper place to
put it.

PiperOrigin-RevId: 731239831
2025-02-26 03:08:42 -08:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
jax authors
0f8e6b996d Typecheck pallas.CostEstimate
Passing a float can lead to miscompilations

PiperOrigin-RevId: 730909635
2025-02-25 09:08:29 -08:00
Adam Paszke
3d87a01bea [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before
like conditionals, scans, etc.

PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00
George Necula
c4e0db6f8a [better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
Now that we carry debug informatiion in Jaxpr we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
Sergei Lebedev
c13a2f95d5 [pallas:mosaic_gpu] Use emit_pipeline for pipelining in the lowering
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.

PiperOrigin-RevId: 730824239
2025-02-25 04:14:10 -08:00
Adam Paszke
676acebafa [Pallas:MGPU] Enable lowering for .astype and scalar broadcasts
PiperOrigin-RevId: 730805326
2025-02-25 03:01:11 -08:00
Adam Paszke
71c7622037 [Pallas:MGPU] Change WG semantics convention to represent scalar arrays using scalars
Previously every ShapedArray got converted to an MLIR vector which was more annoying
than helpful.

PiperOrigin-RevId: 730795455
2025-02-25 02:24:06 -08:00
Adam Paszke
d0d5bba645 [Pallas:MGPU] Avoid SMEM->GMEM wait if no outputs are transferred in the pipeline loop
The TMA wait does not add much overhead, but it lets us save on an unnecessary warpgroup barrier.

PiperOrigin-RevId: 730795234
2025-02-25 02:22:25 -08:00
Sergei Lebedev
7eadc64b5a [pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p
PiperOrigin-RevId: 730780335
2025-02-25 01:32:43 -08:00
Adam Paszke
80848ad859 [Pallas:MGPU] Consistently use i32 as the grid index type in emit_pipeline
That's more consistent with Pallas semantics and avoids generating a slightly
different kernel depending on x32/x64 mode.

PiperOrigin-RevId: 730778314
2025-02-25 01:24:50 -08:00
Sergei Lebedev
74b2e0203f [pallas:mosaic_gpu] Use {min,max}imumf instead of {min,max}numf
PiperOrigin-RevId: 730154865
2025-02-23 09:52:48 -08:00
Daniel Suo
2d1bc5c2a0 Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
2025-02-21 09:45:59 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Jacob Burnim
962eb41933 [Mosaic] Several fixes/improvements for the new TPU interpret mode.
- Checks bounds for reads and writes to shared memory.
 - Pads kernel arguments when necessary.
 - Fix support for input-output aliasing.
 - Fix handling of vmap'ed dimensions.
 - Supports un-masked `pl.load` and masked or un-masked `pl.swap`.
 - Switch to using single integer device IDs instead of tuples.
 - Better error messages for unsupported primitives: `for_p`, `atomic_rmw_p`, and `atomic_cas_p` .

PiperOrigin-RevId: 727301519
2025-02-15 08:35:55 -08:00