137 Commits

Author SHA1 Message Date
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Matthew Johnson
6ea8a546f6 always lower all_to_all to AllToAll 2023-04-11 18:31:17 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Jake VanderPlas
671c72a782 Update signature of ad.defbilinear to simplify transpose rules 2023-01-31 09:07:39 -08:00
Jake VanderPlas
dc862a9279 psum: fix docstring formatting 2022-12-27 12:55:48 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Qiao Zhang
9fda20fc29 Add examples to lax.psum to illustrate axis_index_groups better.
PiperOrigin-RevId: 497401892
2022-12-23 12:04:59 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Anselm Levskaya
ffb4711969 Expose channel_id in AllToAllOp in both XLA builder and MHLO.
PiperOrigin-RevId: 494334791
2022-12-09 21:58:28 -08:00
Peter Hawkins
516f0d0d0a Support negative axes in all_gather.
Previously we didn't check for these and they caused crashes during MHLO verification.

PiperOrigin-RevId: 493160581
2022-12-05 17:48:50 -08:00
jax authors
18f77a526b Add inferReturnTypes for PartitionIdOp.
Same as what already exists for ReplicaIdOp.

PiperOrigin-RevId: 491947476
2022-11-30 10:03:38 -08:00
Sholto Douglas
92fd8534cd Fixes b/259636412, all-gather failing when called within xmap in pjit. Piece by piece making xmap in pjit work with all collectives so that we can use it to write 'manual kernels' safely!
PiperOrigin-RevId: 491749906
2022-11-29 15:16:23 -08:00
Johannes Reifferscheid
da4108d5e0 mhlo.all_to_all: support tuple form in importer/exporter.
PiperOrigin-RevId: 490560403
2022-11-23 12:24:12 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Adam Paszke
d742e6a410 Transpose all_gather to reduce_scatter
Also, add support for AD and batching of reduce_scatter (with its transpose being all_gather again).

PiperOrigin-RevId: 488706478
2022-11-15 11:03:22 -08:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Yash Katariya
63be0c3815 Guard the new channel_handle feature on mlir_api_version for backwards compatibility
PiperOrigin-RevId: 481246613
2022-10-14 15:28:30 -07:00
Adam Paszke
746dd5ab13 Add support for MANUAL lowering of ppermute
PiperOrigin-RevId: 481157480
2022-10-14 09:02:55 -07:00
Parker Schuh
361d3fe553 Add an experimental custom_partitioner API which allows
customizing the partitioning rules.

PiperOrigin-RevId: 481032649
2022-10-13 18:37:21 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Adam Paszke
ffd34d5ad7 Allow collectives in manually sharded computations
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.

PiperOrigin-RevId: 467895089
2022-08-16 04:54:14 -07:00
George Necula
3d9c8fbe6f [dynamic-shapes] Ensure that the axis_size_env is passed to sub lowering contexts 2022-07-12 12:44:23 +03:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Peter Hawkins
cc0f51603d [JAX] Don't expand complex all-reduce ops into real/complex pairs.
[XLA:CPU] Implement complex all-reductions for sum and product.

Fixes https://github.com/google/jax/issues/11133 by making XLA implement the all-reduction whenever we build one, not just the one path on which we happened to have a workaround.

PiperOrigin-RevId: 455687275
2022-06-17 13:45:24 -07:00
Xin Zhou
c017d09767 [mhlo] Add result type inference for mhlo.alltoall.
PiperOrigin-RevId: 449591261
2022-05-18 15:24:22 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Eugene Burmako
90f926ac6b [MHLO] Migrate mhlo.all_reduce to HLO_CompatibleOperandsAndResultType
This runs into the currently unsupported feature in Python bindings which prevents it from taking advantage of the type inference functionality provided by HLO_CompatibleOperandsAndResultType.

PiperOrigin-RevId: 447844880
2022-05-10 15:39:18 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Anudhyan Boral
a147046d18 Add unary xeinsum and allow named axis reductions for unary and binary xeinsums 2022-04-26 09:55:42 +00:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Peter Hawkins
cbdcdf7401 [MHLO] Add MHLO lowerings for parallel collectives.
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
Peter Hawkins
f7ba328e7a Ensure that pdot parameters are hashable.
PiperOrigin-RevId: 416804949
2021-12-16 07:22:59 -08:00
Peter Hawkins
b0646557ee Change primitive arguments to parallel and sparse primitives to make parameters hashable.
An upcoming change adds a cache keyed on (primitive, params), but to do that, we need the params to be hashable.

PiperOrigin-RevId: 416793521
2021-12-16 06:13:32 -08:00
Matthew Johnson
c555f5f0e4 handle trivial case for ppermute batching rule
fixes #8688
2021-12-14 10:42:05 -08:00
Peter Hawkins
1f2d8c0c07 In CPU all_gather lowering, make sure the outputs are bools if the inputs are bools.
PiperOrigin-RevId: 414045093
2021-12-03 16:12:03 -08:00
Peter Hawkins
42647e013f [MLIR] Make jit(pmap(...)) work in the MLIR lowering.
This is redundant with the XLA lowering, but it's probably not the end of the world as a temporary state. An alternative would have been to port the _xla_shard/_xla_unshard primitives to the LAX level and to use xla.lower_fun, but it's not immediately obvious to me how to access ReplicaId() without defining a new primitive. lax.axis_index() is similar but not identical.

Add an axis_env argument to xla.primitive_subcomputation for use by the MLIR fallback path.

PiperOrigin-RevId: 413124116
2021-11-30 05:34:34 -08:00
Peter Hawkins
db0e3fbea9 Reenable pytype checking for jax._src.lax.lax.
pytype checking for this module is no longer excessively slow after the module was split.

PiperOrigin-RevId: 412098920
2021-11-24 11:15:40 -08:00