Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
1. `axis_types` now takes an `AxisTypes | tuple[AxisTypes, ...] | None`. It no longer takes a dictionary.
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
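A minimal sketch of the new tuple form (one entry per mesh axis), assuming the enum is exposed as `jax.sharding.AxisTypes`; exact import locations may differ across JAX versions:
```
import jax
from jax.sharding import AxisTypes  # assumed location of the enum

# One axis type per mesh axis, passed as a tuple (or a single AxisTypes value).
mesh = jax.make_mesh((jax.device_count(),), ('x',),
                     axis_types=(AxisTypes.Explicit,))
```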
PiperOrigin-RevId: 736360041
This would also make it easier to deprecate the `with mesh: pjit` path from user code in the future, since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
Also improve the `dynamic_update_slice` sharding error by printing `aval.str_short()` instead of the full sharding, because it's concise and gives more info than the current error (i.e. it also adds the shape to the error message).
Also make some formatting changes in scan lowering to make it easier to debug.
PiperOrigin-RevId: 734542862
In this case, the example boils down to:
```
inp1 = f32[16@x, 4]
inp2 = f32[4]
def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1, inp2)
```
This example was breaking in the concat batching rule because we didn't broadcast with the right sharding.
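A hedged, runnable variant of the pseudocode above (assumes the device count divides 16; the expected output sharding in the comment follows the fix described here):
```
import jax, jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
inp1 = jax.device_put(jnp.ones((16, 4)), NamedSharding(mesh, P('x', None)))  # f32[16@x, 4]
inp2 = jnp.ones((4,))                                                        # f32[4]

f = lambda x, y: jnp.concatenate([x, y], axis=-1)
out = jax.vmap(f, in_axes=(0, None))(inp1, inp2)  # expected: f32[16@x, 8]
```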
PiperOrigin-RevId: 733536944
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
] a
in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
] a
in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., `cond`, among others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which is now named
`core.pp_shared_jaxpr`).
For example, consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`
This will decompose into 2 einsums, where the intermediate einsum's output will be of rank `5`:
* `'bthj,bthD->bthjD'`
* `'bthjD,bthi->ijD'`
The out_sharding specified (`P('data', None, None)`) is not compatible with the intermediate einsum: `'bthj,bthD->bthjD'` since the `length of spec (3) != out_aval.ndim (5)`.
This change makes it so that out_sharding is only applied to the contraction that leads to the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**
Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
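As a hedged sketch of the workaround, one can split the einsum manually and attach `out_sharding` only to the final contraction (array names follow the example above; the shapes and the `'data'` mesh axis are assumptions):
```
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# dy: [b, t, h, D], i: [b, t, h, i], j: [b, t, h, j]  (shapes assumed)
tmp = jnp.einsum('bthj,bthD->bthjD', j, dy)           # intermediate: no out_sharding
out = jnp.einsum('bthjD,bthi->ijD', tmp, i,
                 out_sharding=P('data', None, None))  # only on the final contraction
```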
PiperOrigin-RevId: 732205849
* `_partitions` is now canonicalized and only contains tuples, single strings, `None`, or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) or singleton tuples.
* Cache the creation of the sharding on `ShapedArray` since it's expensive to do many times.
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
PiperOrigin-RevId: 731745062
Those APIs don't support that right now anyway, and they raise an ugly `KeyError`. Instead we raise a better error here.
I have added a TODO to get the mesh from the args so that computation-follows-data works, but we can decide to do that in the future if many users request it and don't want to use `use_mesh`.
PiperOrigin-RevId: 730687231
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order**.
* Merging: if **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge block's output.
* Splitting: if the dimension being split is sharded, then the sharding is propagated to the major-most dimension post-split, but only if the spec divides the new shape exactly (see the sketch below).
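A hedged sketch of the splitting case (assumes a mesh axis `'x'` of size 2, i.e. exactly two devices, and a row-major input; the expected output sharding follows the rule above):
```
import jax, jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))  # assumes exactly 2 devices
x = jax.device_put(jnp.arange(16.).reshape(8, 2),
                   NamedSharding(mesh, P('x', None)))   # f32[8@x, 2]

# Splitting the sharded major-most dim 8 -> (2, 4): the 'x' spec (size 2)
# divides the new major-most size 2 exactly, so the sharding is expected to
# propagate to that dim, i.e. f32[2@x, 4, 2].
y = jnp.reshape(x, (2, 4, 2))
```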
PiperOrigin-RevId: 730291595
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
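As a hedged illustration of the new conventions (accessing `debug_info` and its attributes like this is an assumption about the internal representation, not a stable public API):
```
import jax

def f(x):
  return x + 1.0, x * 2.0

dbg = jax.make_jaxpr(f)(1.0).jaxpr.debug_info
print(dbg.arg_names)     # e.g. ('x',); a missing name is now '' rather than None
print(dbg.result_paths)  # e.g. ('result[0]', 'result[1]') with the new "result" prefix
```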
* `bitcast_convert_element_type`
* `cumsum`
* `cumlogsumexp`
* `cumprod`
* `cummax`
* `cummin`
* `reduce_window`
* `reduce_window_sum`
* `reduce_window_max`
* `reduce_window_min`
* `select_and_gather_add`
For the `reduce_window_...` primitives, only trivial windowing is supported along non-replicated dimensions. We can relax the other `NotImplemented` case in the future.
PiperOrigin-RevId: 729910108
If a mesh axis is Explicit, we don't canonicalize closed-over values yet since that may require shape changes. The workaround is for users to pass those arrays as arguments to the `shard_map` instead of closing over them, as sketched below.
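A hedged sketch of the workaround, passing the would-be closed-over array as an explicit argument (the mesh setup, specs, and shapes are illustrative assumptions; the device count is assumed to divide 8):
```
import jax, jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
w = jnp.ones(4)

# Instead of closing over `w`:
#   g = shard_map(lambda x: x * w, mesh=mesh, in_specs=P('x', None), out_specs=P('x', None))
# pass it in as an argument:
g = shard_map(lambda x, w: x * w, mesh=mesh,
              in_specs=(P('x', None), P(None)), out_specs=P('x', None))
out = g(jnp.ones((8, 4)), w)
```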
PiperOrigin-RevId: 728956512
Some caveats of enabling sharding-in-types by default: we'll see tracing cache misses, which lead to lowering and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):
1. Call `jitted_f(arr_ns)` with an array that has a `NamedSharding`, and then call `jitted_f(arr_ps)` with an array of the same shape and dtype but with a `PositionalSharding`.
* This leads to a tracing cache miss because on the second call the aval has no sharding, since it's a `PositionalSharding`. The same applies to calling with any sharding other than `NamedSharding`.
2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)`, and then on the second call pass a numpy array: `jitted_f(numpy_arr)`.
* This also leads to a cache miss because the avals currently don't look at `in_shardings`: the semantics of `in_shardings` are complicated, and I don't think we should change the aval based on them.
**The solution in both cases is to make sure to pass arrays sharded on the same mesh during both calls to jit.**
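A hedged sketch of case 1 (assumes the device count divides 8; `PositionalSharding` is just a stand-in for any non-`NamedSharding` here and is deprecated in newer JAX releases):
```
import jax, jax.numpy as jnp
from jax.sharding import NamedSharding, PositionalSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
jitted_f = jax.jit(lambda a: a * 2)

arr_ns = jax.device_put(jnp.zeros(8), NamedSharding(mesh, P('x')))
arr_ps = jax.device_put(jnp.zeros(8), PositionalSharding(jax.devices()))

jitted_f(arr_ns)  # first call: the aval carries the NamedSharding
jitted_f(arr_ps)  # same shape/dtype, but no sharding on the aval -> tracing cache miss
```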
PiperOrigin-RevId: 728361493
This change raises a better error: `NamedSharding(empty_mesh, P('x'))` would raise an error on construction anyway, but that error is uglier than the one added in this change.
PiperOrigin-RevId: 726253654
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* `jax.jit` no longer sets the abstract mesh context before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any `mesh_cast`s, even in `Explicit` sharding mode (see the sketch below the list)
* Even if `use_mesh` is not used in explicit sharding mode, computation-follows-data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* The check in partial_eval which compares `jaxpr_known.outvars == jaxpr.out_avals` has been relaxed to not check shardings if any of the avals has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474, we need to relax the typing and sharding rule checks because inserting `mesh_cast`s leads to the creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh), which is not good.
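A hedged sketch of the relaxed rule from the list above (the enum name `AxisTypes`, `use_mesh`, and their import locations follow the surrounding commits but are assumptions; assumes the device count divides 8):
```
import numpy as np
import jax, jax.numpy as jnp
from jax.sharding import AxisTypes, NamedSharding, PartitionSpec as P, use_mesh

mesh = jax.make_mesh((jax.device_count(),), ('x',),
                     axis_types=(AxisTypes.Explicit,))
x = jax.device_put(jnp.ones(8), NamedSharding(mesh, P('x')))  # explicitly sharded
y = np.ones(8)                                                # empty-mesh aval

with use_mesh(mesh):
  out = jax.jit(lambda a, b: a * b)(x, y)  # allowed; no mesh_cast inserted
```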
PiperOrigin-RevId: 726097292
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in the presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
PiperOrigin-RevId: 721007659