We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.
PiperOrigin-RevId: 748735303
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
Treat `None` metadata values as a special instruction not to set (or
to unset, if nested) the corresponding entry.
In particular, this makes it possible to unset metadata within the
sub-computations of higher-order operations (e.g. branches in
conditionals, loop bodies, etc.). This can be used, for example, to
annotate a conditional but not all the operations in its
branches. That is, the HLO for the following function `f` on a scalar
float argument:
```
def cos(x):
with set_xla_metadata(a=None):
return jnp.cos(x)
@jax.jit
def f(x):
with set_xla_metadata(a="b"):
return jax.lax.cond(x < 0., jnp.sin, cos, x)
```
produces an attribute `a` on the conditional and on the sine, but not
on the cosine.
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.
2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.
This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.
* pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty.
* psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty.
Majority of the primitives should use the standard_insert_pbroadcast and standard_vma_rule and I'll add those in the follow up CLs to other primitives
PiperOrigin-RevId: 739225392
This is so that we can catch this exception in backward_pass/vmap and add extra message to inform users that this is a potential JAX bug. They should file an issue on the repo.
Currently, we only raise `ShardingTypeError` in one place, but we can expand to all other places in follow up changes. This change sets the machinery up.
Previous error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
```
New error:
```
jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}).
This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues
```
The new added message of `This is a potential JAX bug...` is important because this error is raised in the backward pass which is 100% a JAX bug given that forward pass did not error.
PiperOrigin-RevId: 739053305
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
PiperOrigin-RevId: 736859418
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().
hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)
PiperOrigin-RevId: 736555016
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
PiperOrigin-RevId: 736360041
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)
Also make some formatting changes in scan lowering to make it easier to debug.
PiperOrigin-RevId: 734542862
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
] a
in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
] a
in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.
* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
PiperOrigin-RevId: 731745062
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.
PiperOrigin-RevId: 730939406
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.
PiperOrigin-RevId: 728956512
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
For future reference, this can be done via
python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
while IFS=: read file line rest; do
echo "$file:$line";
gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
done < /tmp/unused.txt