This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.
PiperOrigin-RevId: 555124144
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
2) Fallback lowerings now produce StableHLO ops as well.
3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).
From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.
PiperOrigin-RevId: 497978733
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.
PiperOrigin-RevId: 467895089
[XLA:CPU] Implement complex all-reductions for sum and product.
Fixes https://github.com/google/jax/issues/11133 by making XLA implement the all-reduction whenever we build one, not just the one path on which we happened to have a workaround.
PiperOrigin-RevId: 455687275
This runs into the currently unsupported feature in Python bindings which prevents it from taking advantage of the type inference functionality provided by HLO_CompatibleOperandsAndResultType.
PiperOrigin-RevId: 447844880
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
This is redundant with the XLA lowering, but it's probably not the end of the world as a temporary state. An alternative would have been to port the _xla_shard/_xla_unshard primitives to the LAX level and to use xla.lower_fun, but it's not immediately obvious to me how to access ReplicaId() without defining a new primitive. lax.axis_index() is similar but not identical.
Add an axis_env argument to xla.primitive_subcomputation for use by the MLIR fallback path.
PiperOrigin-RevId: 413124116