This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new `util.test_event(...)` function that can be called at points of interest in the program. `test_utils` registers a callback that is invoked whenever such an event is emitted. This avoids the need for thread-unsafe global monkey patches.
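A minimal sketch of how such a hook could look (names and signatures here are assumptions for illustration, not the actual implementation):
```python
# Hypothetical sketch of the event hook; the real util.test_event may differ.
from typing import Callable, Optional

_test_event_callback: Optional[Callable[..., None]] = None

def test_event(name: str, *args) -> None:
  # Cheap no-op in production; only forwards to a callback that tests install.
  if _test_event_callback is not None:
    _test_event_callback(name, *args)

# In test_utils: install/remove the callback around a test.
def set_test_event_callback(cb: Optional[Callable[..., None]]) -> None:
  global _test_event_callback
  _test_event_callback = cb
```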
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.
PiperOrigin-RevId: 705252534
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.
At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).
Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.
We do not check whether the inferred layouts can actually be lowered; as a
result, the produced IR may still fail to lower later.
Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.
When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.
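As a rough illustration, the two passes plus the strided fallback behave like
the following toy sketch (pure Python over a straight-line chain of ops; the
real implementation works on MLIR ops and their `in_layouts`/`out_layouts`
attributes, and all names below are illustrative):
```python
STRIDED = "strided_fragmented"  # stand-in for the strided fragmented layout

def infer_layouts(ops):
  """ops: list of dicts in producer order, each with an optional 'layout'."""
  # Backward pass (root-to-parameters): propagate known layouts toward
  # producers, so partial annotations on intermediate ops (e.g. wgmma)
  # flow upward.
  for i in reversed(range(len(ops) - 1)):
    if ops[i].get("layout") is None:
      ops[i]["layout"] = ops[i + 1].get("layout")
  # Forward pass (parameters-to-root): fill in what the first pass missed.
  for i in range(1, len(ops)):
    if ops[i].get("layout") is None:
      ops[i]["layout"] = ops[i - 1].get("layout")
  # Fallback: default to strided fragmented layouts.
  for op in ops:
    if op.get("layout") is None:
      op["layout"] = STRIDED
  return ops

print(infer_layouts([{}, {"layout": "wgmma_fragmented"}, {}]))
# -> all three ops annotated with "wgmma_fragmented"
```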
PiperOrigin-RevId: 705092403
Without this type hint, some tools (including PyCharm) infer the more
generic return type from typing.NamedTuple.
To improve user experience, I've added a narrower type hint.
However, the typing of this method is still 'flawed', since the only properly
supported input is another Rotation. Annotating the parameter with that
narrower type would violate the Liskov substitution principle, so I left the
input parameter untyped.
For more info:
https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
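A condensed sketch of the situation (the class shape here is assumed for illustration):
```python
from typing import NamedTuple

class Rotation(NamedTuple):
  quat: tuple  # placeholder field; stand-in for the real quaternion storage

  def __mul__(self, other) -> "Rotation":
    # The return annotation gives tools the narrow type instead of the
    # generic one inherited from typing.NamedTuple. `other` stays untyped:
    # annotating it as Rotation would narrow the inherited parameter type
    # and violate the Liskov substitution principle.
    return Rotation(quat=(self.quat, other.quat))  # placeholder composition
```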
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b = c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4).
Now we also add the equality constraints to the inequality
reasoning state.
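A hypothetical repro, assuming the public `jax.export` symbolic-shape API:
```python
from jax import export

# Declare symbolic dimensions with the equality constraint "4*b == c".
b, c = export.symbolic_shape("b, c", constraints=["4*b == c"])

# Previously only normalization used the equality, so this could not be
# decided; with the equality also in the inequality reasoning state,
# b <= 4*b <= c now proves it.
assert b <= c
```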
Also support `Auto` mode, either fully or mixed with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in the `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.
During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would allow XLA to shard them even further, which is not what we want when some of the dims are user-sharded.
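For illustration, the substitution amounts to something like this sketch (handling only plain string entries; `override_auto_axes` is a hypothetical name, not the actual code path):
```python
from jax.sharding import PartitionSpec as P

def override_auto_axes(spec: P, auto_axes: set) -> P:
  # Replace entries that refer to Auto-mode mesh axes with UNCONSTRAINED;
  # entries for User-mode axes are kept as-is.
  return P(*(P.UNCONSTRAINED if entry in auto_axes else entry
             for entry in spec))

# e.g. with mesh axis "x" in Auto mode and "y" in User mode:
print(override_auto_axes(P("x", "y"), auto_axes={"x"}))
# PartitionSpec(UNCONSTRAINED, 'y')
```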
PiperOrigin-RevId: 704911253
This is doubly non-public: nothing under `jax.lib` is public, and the object itself has a leading underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in adaf1b2b75).
PiperOrigin-RevId: 704825268
Move the parsing of a sharding rule string to a free function
`str_to_sdy_sharding_rule`. Move the building of the MLIR sharding rule to a free
function `sdy_sharding_rule_to_mlir`.
PiperOrigin-RevId: 704818640
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```
For now, the API assumes the `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` defaults to all replicas (i.e. there is only one group, covering all axis indices, iota-like, as in the example above).
The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html, which essentially also performs a ragged all-to-all.
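A sketch of how a call could look (argument order assumed from the six operands in the module above; the exact signature may differ, and the call must run under a mapped axis, e.g. inside `shard_map`):
```python
import jax

def step(operand, output, input_offsets, send_sizes, output_offsets,
         recv_sizes):
  # Each shard sends operand[input_offsets[i] : input_offsets[i] +
  # send_sizes[i]] to peer i, receiving into `output` at output_offsets,
  # much like MPI_Alltoallv.
  return jax.lax.ragged_all_to_all(
      operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
      axis_name="x")
```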
PiperOrigin-RevId: 704550890