319 Commits

Author SHA1 Message Date
Yash Katariya
06ad3528e9 Use _make_lengths_same for explicit mode too.
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.

PiperOrigin-RevId: 748735303
2025-04-17 10:48:46 -07:00
Yash Katariya
82215f660e Remove jax_varying_axes_in_types config and rewrite from shard_map_p
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Yash Katariya
75e4279e32 Set jax_varying_axes_in_types to True by default.
PiperOrigin-RevId: 745739477
2025-04-09 14:40:31 -07:00
Yash Katariya
84016bc368 Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Sergei Lebedev
af072feb5a Removed redundant passes
If a function or class has a docstring, it does not need a `pass`.

PiperOrigin-RevId: 745052107
2025-04-08 02:38:21 -07:00
Sergei Lebedev
8ed59d8b5d Removed jax._src.raise_to_shaped
It is just an identity after the "stackless" rewrite.

PiperOrigin-RevId: 745042532
2025-04-08 02:06:40 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00
Sergei Lebedev
51c224c446 Removed deprecated jax.core.{full_lower,jaxpr_as_fun,lattice_join}
PiperOrigin-RevId: 744754730
2025-04-07 09:50:43 -07:00
Sergei Lebedev
c2aa811cd6 jex.core.Var is no longer ordered
This behavior was only needed for kfac_jax which has been updated *not* to
rely on variable ordering.

PiperOrigin-RevId: 744691114
2025-04-07 05:50:41 -07:00
Yash Katariya
fc5d9a4fce Check that memory_kind of an aval is always None
PiperOrigin-RevId: 744136969
2025-04-04 19:23:25 -07:00
jax authors
056c976ecb Merge pull request #27660 from froystig:xla-meta-ctx
PiperOrigin-RevId: 743178649
2025-04-02 09:59:33 -07:00
Roy Frostig
1875c76bd2 let XLA metadata be unset in nested dynamic scopes
Treat `None` metadata values as a special instruction not to set (or
to unset, if nested) the corresponding entry.

In particular, this makes it possible to unset metadata within the
sub-computations of higher-order operations (e.g. branches in
conditionals, loop bodies, etc.). This can be used, for example, to
annotate a conditional but not all the operations in its
branches. That is, the HLO for the following function `f` on a scalar
float argument:

```
def cos(x):
  with set_xla_metadata(a=None):
    return jnp.cos(x)

@jax.jit
def f(x):
  with set_xla_metadata(a="b"):
    return jax.lax.cond(x < 0., jnp.sin, cos, x)
```

produces an attribute `a` on the conditional and on the sine, but not
on the cosine.
2025-04-01 20:25:19 -07:00
Yash Katariya
76271d638a Add scan_p and cond_p vma rule.
PiperOrigin-RevId: 742737384
2025-04-01 09:50:38 -07:00
Yash Katariya
5950e722e2 Make sure vma on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00
Yash Katariya
563c3e2244 Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Yash Katariya
25c106d132 Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast for unary ops though)
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py

PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Yash Katariya
e8038501d0 Fix a bug where jit was forwarding inputs to outputs even when donation was True for that inputs. This caused the output to be marked as deleted since the input was being forwarded to the output.
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
2025-03-26 16:31:11 -07:00
Ayaka
c450b69dd7 Add missing __len__ to MutableArray
Fixes https://github.com/jax-ml/jax/issues/27476

PiperOrigin-RevId: 740903637
2025-03-26 14:27:50 -07:00
Yash Katariya
3163fbaac4 Add varying manual axes rules to mul_p and convert_element_type_p. There are 2 things that need to be added:
1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.

2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.

This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.

* pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty.

* psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty.

Majority of the primitives should use the standard_insert_pbroadcast and standard_vma_rule and I'll add those in the follow up CLs to other primitives

PiperOrigin-RevId: 739225392
2025-03-21 10:26:18 -07:00
Yash Katariya
c7d6b653ce [sharding_in_types] Add core.ShardingTypeError as a new Exception that are sharding-in-types specific errors should raise.
This is so that we can catch this exception in backward_pass/vmap and add extra message to inform users that this is a potential JAX bug. They should file an issue on the repo.

Currently, we only raise `ShardingTypeError` in one place, but we can expand to all other places in follow up changes. This change sets the machinery up.

Previous error:

```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
```

New error:

```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues
```

The new added message of `This is a potential JAX bug...` is important because this error is raised in the backward pass which is 100% a JAX bug given that forward pass did not error.

PiperOrigin-RevId: 739053305
2025-03-20 22:19:08 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Peter Hawkins
074216e07a Precompute a weakref to a Trace≥
We use Trace weakrefs frequently, so we may as well construct one eagerly.

PiperOrigin-RevId: 736841778
2025-03-14 06:26:17 -07:00
Yash Katariya
d3a41d8448 get_sharding doesn't need to be conditioned on the context mesh
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
8effa19734 [JAX] Change jax.core.Trace subclasses to call super().__init__().
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().

hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)

PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Yash Katariya
abcc7fdf4c [sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName] to ShapedArray. Also add jax_varying_axes_in_types config to hide this option under while we develop it.
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Yash Katariya
f8b98993b8 Add a divisibility check so that we make sure that sharding evenly divides the shape (until this restriction is lifted) to make sure we don't create bad shardings.
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)

Also make some formatting changes in scan lowering to make it easier to debug.

PiperOrigin-RevId: 734542862
2025-03-07 07:01:34 -08:00
Yash Katariya
e9486920e8 Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None. So that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
George Necula
a6c47d6f36 Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`.
2025-03-03 11:38:51 +01:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.

* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.

PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00
Yash Katariya
9deb7e3d96 [sharding_in_types] physical_aval should set the correct sharding on ShapedArray so that lowering and compilation don't crash
PiperOrigin-RevId: 730885084
2025-02-25 07:53:14 -08:00
Yash Katariya
6f8bab3c92 Add sharding mismatch to explain_tracing_cache_miss
PiperOrigin-RevId: 730645598
2025-02-24 16:49:49 -08:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Yash Katariya
262aab74f0 canonicalize closed over values if **atleast** 1 mesh axis is Manual and **all other mesh axes** are Manual or Auto. This would make the canonicalization work properly with shmap partial-auto.
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.

PiperOrigin-RevId: 728956512
2025-02-19 22:18:56 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...) is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Yash Katariya
b35083331c Expose get_ty aka get_aval from jax namespace
PiperOrigin-RevId: 728490205
2025-02-18 21:22:19 -08:00
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
jax authors
60dcded2af Merge pull request #26518 from superbobry:maint-2
PiperOrigin-RevId: 726663977
2025-02-13 15:44:19 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Yash Katariya
229aa65a3e Split NamedSharding into a separate file called named_sharding.py so that we can import it in core.py and break the cyclic dependency.
PiperOrigin-RevId: 726566863
2025-02-13 11:22:54 -08:00
Yash Katariya
3ec7a67e51 [sharding_in_types] Make sharding arg to ShapedArray kwarg only
PiperOrigin-RevId: 726272943
2025-02-12 18:22:50 -08:00