There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.
Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).
When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.
PiperOrigin-RevId: 737280270
This is an exact port of the current Python implementation to C++ for speed.
I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.
PiperOrigin-RevId: 737014989
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
PiperOrigin-RevId: 736859418
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.
PiperOrigin-RevId: 736799241
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.
The pass isn't called anywhere yet.
PiperOrigin-RevId: 736758754
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().
hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)
PiperOrigin-RevId: 736555016
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
PiperOrigin-RevId: 736360041
Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
the dialect lowering AFAICT.
PiperOrigin-RevId: 736138554
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.
PiperOrigin-RevId: 735957604