365 Commits

Author SHA1 Message Date
jax authors
e9ce8fb92d Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
jax authors
01a110c4c9 Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
Jacob Burnim
47e8effdce Adds option to initialize buffers to NaNs or zeros in TPU interpret mode. 2025-03-18 12:24:45 -07:00
jax authors
7c5871f464 [Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
PiperOrigin-RevId: 738038138
2025-03-18 09:40:43 -07:00
Sergei Lebedev
051687dc4c [pallas] pallas_call_p is now parameterized by a mesh
The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
2025-03-17 16:30:40 -07:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
02505fa757 [Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter
PiperOrigin-RevId: 735564365
2025-03-10 17:19:39 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
73d20cd62a [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
jax authors
2a1eeb0ce8 Chnages for kernel export
PiperOrigin-RevId: 732383028
2025-03-01 00:32:39 -08:00
Sharad Vikram
6f57410e12 [Pallas TPU] Use grid_env for pipeline body so we can query num_programs/program_id inside the block spec
PiperOrigin-RevId: 731831543
2025-02-27 12:53:02 -08:00
jax authors
da39b6f3d4 Comment change
PiperOrigin-RevId: 731792151
2025-02-27 11:07:59 -08:00
Sharad Vikram
2646b8d4ad [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
PiperOrigin-RevId: 731543938
2025-02-26 19:34:36 -08:00
Sharad Vikram
1ecbac9702 [Pallas] Add name parameter to core_map
PiperOrigin-RevId: 731536152
2025-02-26 18:59:01 -08:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
George Necula
c4e0db6f8a [better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
Now that we carry debug informatiion in Jaxpr we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
Jacob Burnim
962eb41933 [Mosaic] Several fixes/improvements for the new TPU interpret mode.
- Checks bounds for reads and writes to shared memory.
 - Pads kernel arguments when necessary.
 - Fix support for input-output aliasing.
 - Fix handling of vmap'ed dimensions.
 - Supports un-masked `pl.load` and masked or un-masked `pl.swap`.
 - Switch to using single integer device IDs instead of tuples.
 - Better error messages for unsupported primitives: `for_p`, `atomic_rmw_p`, and `atomic_cas_p` .

PiperOrigin-RevId: 727301519
2025-02-15 08:35:55 -08:00
Marcello Maggioni
9a8c9a56cf [JAX] Allow pallas to accept scalar shape semaphores.
PiperOrigin-RevId: 727198066
2025-02-14 23:20:47 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Adam Paszke
f1ab7514db Make sure we take libTPU version into account in the Pallas lowering
Also, strengthen the presubmit to make sure we catch more errors.

PiperOrigin-RevId: 726061633
2025-02-12 08:15:57 -08:00
Marcello Maggioni
6c6b5ec582 [JAX/Pallas] Add has_side_effect parameter to CompilerParams to stop CSE of operations.
Some pallas kernels shouldn't be CSEd even if they share the same inputs.
For example in async pallas scenarios like when you have a kernel starting some DMAs
that are waited in the user of the kernel (to perform async copies) we can't CSE or kernels
might wait multiple times on a DMA that happens only one.

PiperOrigin-RevId: 725752913
2025-02-11 13:33:01 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
jax authors
ffd3faad72 [TPU[Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
jax authors
b7d012281e Merge pull request #26423 from gnecula:debug_info_jaxpr_7
PiperOrigin-RevId: 725317552
2025-02-10 12:58:26 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Jacob Burnim
1c82484c9b Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
2025-02-06 13:04:14 -08:00
Sharad Vikram
02f4531310 [Pallas TPU] Add helpers for writing collectives
PiperOrigin-RevId: 723250661
2025-02-04 15:39:10 -08:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Jacques Pienaar
60d3836fdf Propagate source ranges in location.
Previously only the line info was propagated. Given the new source range location support, propagate source range.

PiperOrigin-RevId: 722860932
2025-02-03 17:32:59 -08:00
Christos Perivolaropoulos
8649132d86 [pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.
This CL lays the ground for a future CL that makes run_scoped discharge to not request the discharge of the temporary buffers it creates. This causes issues becausa

a) dma_start can't discharge some but not all its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (or it goes in an infinite loop).

PiperOrigin-RevId: 722126566
2025-02-01 08:23:23 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Aaron Russell Voelker
4173842736
add f-string to mosaic memory space error msg 2025-01-17 20:16:36 -05:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Sharad Vikram
0ac63157f5 [Pallas TPU] Add helpers file with copy_ref function
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
jax authors
c4406d2759 [pallas] Fix bad rebase, deleted lowering for a print
PiperOrigin-RevId: 715694818
2025-01-15 01:18:30 -08:00
jax authors
c18492be65 [pallas][mosaic kernel export] Add initial support for exporting a dynamic shapes (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00