745 Commits

Author SHA1 Message Date
jax authors
2c165bffc9 [pallas:triton] Lift dot_general restriction on minimal tile size for a.
PiperOrigin-RevId: 725605869
2025-02-11 06:27:16 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
jax authors
ffd3faad72 [TPU[Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
jax authors
b7d012281e Merge pull request #26423 from gnecula:debug_info_jaxpr_7
PiperOrigin-RevId: 725317552
2025-02-10 12:58:26 -08:00
jax authors
6bedabd386 [TPU][Pallas][XLA] Add BUILD time codegen tool that turns a pallas kernel into a parameterized kernel loader header that can be utilized anywhere in C++
Next step here is to write a specialization pass that takes the kernel loaded above and binds values to it (already done in prototype/scratch)

PiperOrigin-RevId: 725271468
2025-02-10 10:45:32 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Sergei Lebedev
e5058079c9 [pallas:mosaic_gpu] Fixed a bug in how delay_release is handled in emit_pipeline
PiperOrigin-RevId: 724395676
2025-02-07 10:37:21 -08:00
Sergei Lebedev
35351f95e4 [pallas:triton] Really revert to the lowering using Triton IR
PiperOrigin-RevId: 724329911
2025-02-07 06:55:14 -08:00
jax authors
6ad38af473 Merge pull request #26368 from ROCm:fix-rocm-pallas-lowerings
PiperOrigin-RevId: 724328946
2025-02-07 06:52:01 -08:00
Jacob Burnim
1c82484c9b Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
2025-02-06 13:04:14 -08:00
Mathew Odden
10c2374f61 Fix invalid lowerings for ROCm in Pallas
popcount and clz were effectively broken on ROCm,
since math_dialect fallbacks were resulting in
incorrect lowerings during compilation in XLA.

Use the device intrinsics for these functions, as
well as for exp and absf, which fixes some accuracy issues in
the pallas tests.

Docs for OCML/OCKL

- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCML.md
- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCKL.md

Co-Authored-By: <jason.furmanek@amd.com>
2025-02-06 13:59:58 -06:00
Sergei Lebedev
efbb0afd7a [pallas:triton] Temporarily reverted to the lowering using Triton IR
The new lowering caused a performance regression internally.

PiperOrigin-RevId: 723934141
2025-02-06 07:53:04 -08:00
George Necula
904b74860c [better_errors] Continue adding debug info to Jaxprs (step 3)
This follows after #26078, and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
2025-02-06 16:26:49 +02:00
Christos Perivolaropoulos
eeace3ceba [pallas:mgpu] Cast all indices to i32 during lowering.
PiperOrigin-RevId: 723505268
2025-02-05 07:37:04 -08:00
Adam Paszke
1fbc4a15dd [Mosaic GPU] Infer whether A/B are row- or column-major from strides
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape), only the in-memory representation changes.

Other than the API change, this is a NFC.

PiperOrigin-RevId: 723449720
2025-02-05 04:01:04 -08:00
Sharad Vikram
02f4531310 [Pallas TPU] Add helpers for writing collectives
PiperOrigin-RevId: 723250661
2025-02-04 15:39:10 -08:00
jax authors
414449e142 Merge pull request #26078 from gnecula:debug_info_jaxpr
PiperOrigin-RevId: 723151082
2025-02-04 10:54:26 -08:00
George Necula
d12aead696 [better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
2025-02-04 10:02:35 +02:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Jacques Pienaar
60d3836fdf Propagate source ranges in location.
Previously only the line info was propagated. Given the new source range location support, propagate source range.

PiperOrigin-RevId: 722860932
2025-02-03 17:32:59 -08:00
Sergei Lebedev
f58207a28d [pallas:triton] Fixed dispatch tablee for lax.pow_p
PiperOrigin-RevId: 722817510
2025-02-03 15:17:58 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Sergei Lebedev
bf6489ff5b [pallas:triton] Fallback lowering rules for math functions now use general dtypes
Previously, it was necessary to list all dtypes explicitly, which is why
we had separate fallback rules for float16 and bfloat16 for some functions.

PiperOrigin-RevId: 722729554
2025-02-03 11:21:11 -08:00
jax authors
7e353913f2 Merge pull request #26262 from gnecula:debug_info_one
PiperOrigin-RevId: 722684417
2025-02-03 09:17:13 -08:00
Christos Perivolaropoulos
b48d15d788 [pallas_mgpu] For loops can have **non-ref** accumulators for carries.
The user has access only to accumulator references and they can't pass them as caries to loops. However when they are discharged these accumulators become values and become part of the carry. Before this CL this would surprise the loop lowering code.

This was never a problem for pallas mgpu until we added pipelining loops instead of sequential bloc axes.

PiperOrigin-RevId: 722495749
2025-02-02 21:03:26 -08:00
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
2025-02-02 06:23:03 +02:00
Christos Perivolaropoulos
b23f8f414b [pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates
When we do run_scoped[jaxpr, R1,R2], it can't be assumed that references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (eg
Accumulator) but sometimes they can't (eg SMEM scratch). It should be up to the
lowering rule to do such discharging.

This further means that during lowering there is no guarantee that the
references will not be used/returned by nested scoped blocks so we also remove
that check.

PiperOrigin-RevId: 722137352
2025-02-01 09:37:03 -08:00
Christos Perivolaropoulos
8649132d86 [pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.
This CL lays the ground for a future CL that makes run_scoped discharge to not request the discharge of the temporary buffers it creates. This causes issues becausa

a) dma_start can't discharge some but not all its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (or it goes in an infinite loop).

PiperOrigin-RevId: 722126566
2025-02-01 08:23:23 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00
George Necula
878272ee3c [better_errors] Refactor more uses of pe.tracing_debug_info (part 2)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following #26097) for Pallas.
2025-01-27 16:10:56 +02:00
Adam Paszke
c10b9b88f2 [Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
Also add small "getting started" examples that use the helpers in tests.

PiperOrigin-RevId: 719303512
2025-01-24 07:59:26 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Justin Fu
10bb38bb79 [Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
PiperOrigin-RevId: 718451678
2025-01-22 10:59:58 -08:00
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
George Necula
4fd0bb05b1 [better_errors] Finally remove api_util.debug_info.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
2025-01-20 11:19:53 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Aaron Russell Voelker
4173842736
add f-string to mosaic memory space error msg 2025-01-17 20:16:36 -05:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Sharad Vikram
0ac63157f5 [Pallas TPU] Add helpers file with copy_ref function
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
jax authors
70c1ee5d9c Merge pull request #25876 from gnecula:debug_info_3
PiperOrigin-RevId: 715831527
2025-01-15 09:35:03 -08:00
Sergei Lebedev
afcb21ddf1 [pallas:mosaic_gpu] Fixed a crash in MLIR Python bindings
The error message produced by MLIR is not really clear, but AFAICT the crash
was caused by the "temporary module" hack we use in the lax.cond lowering
rule.

PiperOrigin-RevId: 715785632
2025-01-15 07:09:43 -08:00
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
jax authors
c4406d2759 [pallas] Fix bad rebase, deleted lowering for a print
PiperOrigin-RevId: 715694818
2025-01-15 01:18:30 -08:00
jax authors
c18492be65 [pallas][mosaic kernel export] Add initial support for exporting a dynamic shapes (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00