... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.
PiperOrigin-RevId: 467895089
[XLA:CPU] Implement complex all-reductions for sum and product.
Fixes https://github.com/google/jax/issues/11133 by making XLA implement the all-reduction whenever we build one, not just the one path on which we happened to have a workaround.
PiperOrigin-RevId: 455687275
This runs into the currently unsupported feature in Python bindings which prevents it from taking advantage of the type inference functionality provided by HLO_CompatibleOperandsAndResultType.
PiperOrigin-RevId: 447844880
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
This is redundant with the XLA lowering, but it's probably not the end of the world as a temporary state. An alternative would have been to port the _xla_shard/_xla_unshard primitives to the LAX level and to use xla.lower_fun, but it's not immediately obvious to me how to access ReplicaId() without defining a new primitive. lax.axis_index() is similar but not identical.
Add an axis_env argument to xla.primitive_subcomputation for use by the MLIR fallback path.
PiperOrigin-RevId: 413124116
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.
PiperOrigin-RevId: 411921794
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
* Always propagate the axis environment, and remove the parallel argument to lower_fun() because it is no longer needed.
* Don't update the name_stack in the MLIR version. The XLA version no longer does this.
* Simplify the call signature of the MLIR version by always accepting avals_out but noting that it is ignored so it is legal to pass, say, None.
PiperOrigin-RevId: 410875100
This is a more descriptive name and a better location (next to other facilities for building XLA IR).
Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.
PiperOrigin-RevId: 404270649
The new version does *not* canonicalize dtypes. We should be canonicalizing dtypes as part of tracing to a jaxpr, not in any way as part of XLA lowering. In all cases as best I can tell the dtypes from the callers are already canonical anyway.
jax.interpreters.xla is also a better location: I'm not even sure why we have a bunch of random things in xla_bridge any more, so it makes sense to consolidate them in xla.py along with the other registrations for things like avals.
Also delete the unused function xla_bridge.supported_numpy_dtypes.
PiperOrigin-RevId: 404246574
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
The all-gather and reduce-scatter HLOs were wired through for GPU but not TPU, but they should also work there (and be more performant than the all-reduce based fallback).
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.
PiperOrigin-RevId: 384897270