140 Commits

Author SHA1 Message Date
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
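
A minimal pure-Python sketch of the idea, for illustration only. The real helpers live in JAX's internal utilities and may differ in detail; `_checked_lists` is a hypothetical helper introduced here.

```python
def _checked_lists(args):
    """Materialize the iterables and insist they all have the same length."""
    lists = [list(a) for a in args]
    if not lists:
        raise TypeError("expected at least one iterable")
    n = len(lists[0])
    for i, lst in enumerate(lists[1:], start=1):
        if len(lst) != n:
            raise ValueError(
                f"length mismatch: argument {i} has length {len(lst)}, expected {n}")
    return lists


def safe_map(f, *args):
    """Eager map() with a length check; builds and returns the result list."""
    return list(map(f, *_checked_lists(args)))


def foreach(f, *args):
    """Same length check, but f is called only for its side effects."""
    for items in zip(*_checked_lists(args)):
        f(*items)
```

With `foreach`, a loop such as `foreach(register, names, values)` avoids allocating a list of `None` results that the caller would discard anyway.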

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Matthew Johnson
251b93ebd7 fixups that we meant to include in #26427
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-08 00:03:26 +00:00
jax authors
6095af050f Merge pull request #26427 from mattjj:direct-linearize-fixes
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9 [mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
2025-03-07 20:53:29 +00:00
Sharad Vikram
b5fcffadd4 Add swap as method to TransformedRef
PiperOrigin-RevId: 731541165
2025-02-26 19:19:10 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If an axis in shmap(..., auto=...) is an explicit axis in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
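
For readers without GNU sed, the substitution can be sketched in Python. This is an assumed equivalent of the sed expression above, not part of the original commit; `strip_type_ignore` is a hypothetical name.

```python
import re

# Mirrors the sed pattern: optional leading spaces, the marker itself,
# then zero or more bracketed error-code groups such as [arg-type].
_TYPE_IGNORE = re.compile(r" *# type: ignore(\[[^\]]*\])*")


def strip_type_ignore(line: str) -> str:
    """Remove the first `# type: ignore[...]` marker from one source line."""
    # count=1 matches sed's default of replacing only the first occurrence.
    return _TYPE_IGNORE.sub("", line, count=1)
```

Applied to each line number that mypy reports, this leaves the rest of the line untouched, including any trailing newline.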
2025-02-13 21:12:27 +00:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks a little bit less strict when the current or aval mesh is empty/unset. Also some more changes as listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if use_mesh is not used in explicit sharding mode, computation follows data works!

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the avals has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
jax authors
c46b0215b0 Merge pull request #26313 from gnecula:debug_info_vjp
PiperOrigin-RevId: 723575296
2025-02-05 10:58:10 -08:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078, we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Sharad Vikram
782215b099 Add get/set methods to TransformedRef
PiperOrigin-RevId: 723188696
2025-02-04 12:34:34 -08:00
jax authors
414449e142 Merge pull request #26078 from gnecula:debug_info_jaxpr
PiperOrigin-RevId: 723151082
2025-02-04 10:54:26 -08:00
George Necula
d12aead696 [better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
2025-02-04 10:02:35 +02:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.
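
The conditions listed above can be read as a small decision rule. The sketch below is a toy model only: it collapses a mesh to a single axis mode (or `None` for unset), and `AxisMode`, `FINAL_STYLE`, and `needs_mesh_cast` are hypothetical names, not JAX APIs.

```python
from enum import Enum
from typing import Optional


class AxisMode(Enum):
    AUTO = "Auto"
    EXPLICIT = "Explicit"
    MANUAL = "Manual"


# Hypothetical set of final-style primitive names, for illustration.
FINAL_STYLE = frozenset({"pjit", "while", "cond", "scan"})


def needs_mesh_cast(prim_name: str,
                    current_mesh: Optional[AxisMode],
                    aval_mesh: Optional[AxisMode]) -> bool:
    """Decide whether canonicalization would insert a mesh_cast before bind."""
    # mesh_cast itself is skipped to avoid recursion; final-style
    # primitives handle canonicalization in their own bind method.
    if prim_name == "mesh_cast" or prim_name in FINAL_STYLE:
        return False
    if current_mesh is AxisMode.MANUAL and aval_mesh is AxisMode.AUTO:
        return True
    if current_mesh is not None and aval_mesh is None:
        return True
    return False
```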

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
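
One way to picture the merged class is a single record whose result paths are either a thunk or an already-resolved tuple. The sketch below is a simplified stand-in with illustrative field names, not the actual definition of `core.DebugInfo`.

```python
from dataclasses import dataclass, replace
from typing import Callable, Tuple, Union

ResultPaths = Tuple[str, ...]


@dataclass(frozen=True)
class DebugInfo:
    func_name: str
    arg_names: Tuple[str, ...]
    # A thunk while tracing is in progress, a tuple once it is known.
    result_paths: Union[ResultPaths, Callable[[], ResultPaths], None]

    def resolve_result_paths(self) -> "DebugInfo":
        """Return a copy with the thunk evaluated; no-op if already resolved."""
        if callable(self.result_paths):
            return replace(self, result_paths=tuple(self.result_paths()))
        return self
```

Because one type covers both states, moving debug info from a `WrappedFun` to a `Jaxpr` and back needs no conversion between classes, only an optional call to resolve the thunk.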
2025-02-02 06:23:03 +02:00
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Adam Paszke
c10b9b88f2 [Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
Also add small "getting started" examples that use the helpers in tests.

PiperOrigin-RevId: 719303512
2025-01-24 07:59:26 -08:00
Yash Katariya
23d360bded Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
2025-01-22 15:49:04 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but in doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
George Necula
f9dfe7f646 [better_errors] More cleanup
2025-01-15 10:22:29 +00:00
Sharad Vikram
7be127f23c [Pallas] Improvements to core_map
PiperOrigin-RevId: 713075852
2025-01-07 16:18:30 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped
2024-12-12 09:49:06 -08:00
jax authors
263d4d1462 Merge pull request #25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
Jacob Burnim
af5013568a Fix error when swapping a ref with a trivial indexing transform.
Without this fix, the added test case fails with:
```
...
jax/_src/state/discharge.py:416: in _swap_discharge_rule
    z, x_new = _swap_discharge(x, val, idx, tree)
jax/_src/state/discharge.py:421: in _swap_discharge
    return transform_swap_array(x, transforms, val)
jax/_src/state/discharge.py:396: in transform_swap_array
    result_val = lax_slicing.dynamic_update_slice(
jax/_src/lax/slicing.py:215: in dynamic_update_slice
    start_indices = _dynamic_slice_indices(operand, start_indices)
...
AttributeError: 'NoneType' object has no attribute 'ndim'
```
from encountering a None when computing the `result_val`.
2024-12-06 09:12:14 -08:00
jax authors
e707edeafa Merge pull request #25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2024-11-21 09:57:55 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives.
2024-11-21 06:17:01 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
jax authors
6f371212d9 Implements an alternate version of ragged_attention in which the attention kernel itself is dense. That is, this kernel does not have the compute saving (@when wrapped kernel) or prefetch/index skipping (via index rewriting) as part of the kernel. Rather, the kernel is invoked with a Jumble (a ragged type representation) and Pallas takes care of applying the correct work skipping and index rewriting.
Performance-wise, we should be at parity, although this has not yet been tested.

Authoring wise, the new kernel is significantly smaller and simpler to write.

A major known limitation of this approach, which we have a plan to fix, is the invariant that `seq_len % grid_size == 0`; we plan to relax this limitation in following CLs.

PiperOrigin-RevId: 689868468
2024-10-25 12:07:34 -07:00
Sharad Vikram
ce8ecbd16d Add an extension mechanism to run_state that allows:
* Uninitialized values
* Custom ref aval construction

This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
2024-10-23 08:00:56 -07:00
Adam Paszke
f08801b8d6 [Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
We only need to exchange the transforms preceding the indexer, while
the rest can remain unmodified.

PiperOrigin-RevId: 688112088
2024-10-21 06:22:16 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names
2024-10-16 17:42:41 -07:00
Jevin Jiang
3a7d9137a4 [Pallas TPU] Support ref reshape.
Jaxpr example:
```
{ lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:

- DMA with reshaped ref
- Load from reshaped ref
- Store to reshaped ref
- Multiple transforms
- Interpret Mode for ref transforms (updated discharge rules)
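
The shape bookkeeping in the Jaxpr example above can be checked by hand. The sketch below only verifies that each bitcast preserves the byte count and each reshape preserves the element count; it says nothing about the actual TPU memory layout, and the helper names are hypothetical. Slices are written out manually.

```python
from math import prod

ITEMSIZE = {"int32": 4, "int16": 2, "float16": 2}  # bytes per element


def bitcast(ref, dtype, shape):
    """Reinterpret the bytes of (dtype, shape); total byte count must match."""
    src_dtype, src_shape = ref
    if ITEMSIZE[src_dtype] * prod(src_shape) != ITEMSIZE[dtype] * prod(shape):
        raise ValueError(f"bitcast changes byte count: {ref} -> {(dtype, shape)}")
    return (dtype, shape)


def reshape(ref, shape):
    """Change the shape only; total element count must match."""
    dtype, src_shape = ref
    if prod(src_shape) != prod(shape):
        raise ValueError(f"reshape changes element count: {ref} -> {shape}")
    return (dtype, shape)


ref = ("int32", (16, 256))                   # a[:16,:] of int32[32,256]
ref = bitcast(ref, "int16", (32, 256))
ref = reshape(ref, (2, 16, 256))
ref = bitcast(ref, "float16", (2, 16, 256))
ref = ("float16", (1, 16, 256))              # [1:,:,:]
ref = reshape(ref, (16, 256))
ref = ("float16", (16, 128))                 # [:,:128]
ref = bitcast(ref, "int32", (8, 128))        # matches b: int32[8,128]
```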

PiperOrigin-RevId: 686186426
2024-10-15 11:52:15 -07:00
jax authors
1f0b5728a4 Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.

PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00
Adam Paszke
e2d3bd866a [Pallas/MGPU] Add support for tiled and swizzled loads/stores + support slices
PiperOrigin-RevId: 681370464
2024-10-02 02:44:10 -07:00
Justin Fu
350afaa7b6 [Pallas] Clean up lowering exceptions.
PiperOrigin-RevId: 681073628
2024-10-01 10:26:40 -07:00
Christos Perivolaropoulos
84fc011e27 Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.
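
The rule lookup described above can be pictured with a toy registry. This is a schematic sketch, not the code in `jax/_src/state/discharge.py`: the registries, rule signatures, and `discharged_refs` are hypothetical, and refs are modeled as plain names.

```python
from typing import Callable, Dict, List, Sequence

Rule = Callable[[Sequence[str], Sequence[bool]], List[str]]


def discharge_all(ref_names, should_discharge):
    """Normal rule: ignores the mask and discharges every ref argument."""
    return list(ref_names)


def discharge_selected(ref_names, should_discharge):
    """Partial rule: discharges only the refs the caller asked for."""
    return [name for name, flag in zip(ref_names, should_discharge) if flag]


normal_rules: Dict[str, Rule] = {"cond": discharge_all, "scan": discharge_all}
partial_rules: Dict[str, Rule] = {"cond": discharge_selected}


def discharged_refs(prim, ref_names, should_discharge):
    """Prefer the partial rule when one is registered for the primitive."""
    rule = partial_rules.get(prim) or normal_rules[prim]
    return rule(ref_names, should_discharge)
```

In this model `cond` honors the mask while `scan`, which has no partial rule yet, still discharges everything it touches.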

PiperOrigin-RevId: 681021324
2024-10-01 08:03:58 -07:00
Adam Paszke
cac2b8d5fc [Pallas/MGPU] Undo transforms before giving refs back to users
This is a second attempt at this change. The first one was rolled back because of reported failures.

Reverts 411928b9668570bbc3795522aba94cece6894881

PiperOrigin-RevId: 680943744
2024-10-01 03:32:40 -07:00
jax authors
411928b966 Rollback because of breakages
Reverts 21fea5b0db7a8d3fcd9d6918b430b0ebdd4da3e5

PiperOrigin-RevId: 680552566
2024-09-30 07:23:36 -07:00
Adam Paszke
21fea5b0db [Pallas/MGPU] Undo transforms on refs before giving them back to the users
This change makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.
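
A shape-only sketch of the trick, assuming two simple transforms. `Tile`, `Transpose`, and the helper functions are hypothetical stand-ins for the Mosaic GPU transforms and track shapes only, not data.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Tile:
    tiling: Tuple[int, int]

    def apply(self, shape):
        (m, n), (tm, tn) = shape, self.tiling
        if m % tm or n % tn:
            raise ValueError(f"{shape} is not divisible by tiling {self.tiling}")
        return (m // tm, n // tn, tm, tn)

    def undo(self, shape):
        mt, nt, tm, tn = shape
        return (mt * tm, nt * tn)


@dataclass(frozen=True)
class Transpose:
    perm: Tuple[int, ...]

    def apply(self, shape):
        return tuple(shape[i] for i in self.perm)

    def undo(self, shape):
        out = [0] * len(shape)
        for dst, src in enumerate(self.perm):
            out[src] = shape[dst]
        return tuple(out)


def transformed_shape(block_shape, transforms):
    """Shape of the fully transformed ref: apply the transforms in order."""
    for t in transforms:
        block_shape = t.apply(block_shape)
    return block_shape


def user_visible_shape(physical_shape, transforms):
    """Stack the inverses in reverse order to recover the block-spec shape."""
    for t in reversed(transforms):
        physical_shape = t.undo(physical_shape)
    return physical_shape
```

For a `(128, 96)` block with `[Tile((64, 32)), Transpose((1, 0, 2, 3))]`, the transformed ref has four dimensions, yet undoing the stack gives back the 2D shape the user wrote in the block spec, while the transform list stays available for checks such as tiling validation.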

PiperOrigin-RevId: 680520185
2024-09-30 04:43:08 -07:00