166 Commits

Author SHA1 Message Date
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
jax authors
ede94c3c81 Rollback of https://github.com/google/jax/pull/20705
Causing pmap_test.py failures.

Reverts a7bce471440dda2a8bbeed1fe01dd9f733ef5bbc

PiperOrigin-RevId: 638437174
2024-05-29 15:46:55 -07:00
jax authors
a7bce47144 Merge pull request #20705 from chaserileyroberts:chase/pbroadcast_channel_fix
PiperOrigin-RevId: 637986186
2024-05-28 12:29:40 -07:00
Michael Levesque-Dion
43f51d73ce Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
2024-05-28 10:58:51 -07:00
Chase Roberts
af6970e432 Pipe channel handle 2024-05-28 10:20:50 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Samuel
959ecca182
Minor typo fix in docstring jax.lax.psum
Fix code formatting inconsistency in `psum` docstring

Currently, "device2" and "device3" are rendered incorrectly in the JAX documentation (see second example [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.psum.html))
2024-04-12 13:12:00 +01:00
Chase Roberts
01412f7645 pbroadcast 2024-03-18 15:12:33 -07:00
Philip Pham
3fe65e2005 Pipe tiled through all_to_all primitive
The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled`
argument. Thus, for the transpose to know the right value of `tiled` to pass, we
need to plumb the `tiled` argument through the primitive and various
interpreters, even though it's a no-op because the `tiled` argument is handled
outside the primitive. It would be cleaner to handle `tiled` inside the
primitive, but I will leave that for followup work.

Fixes #15982.

PiperOrigin-RevId: 612628600
2024-03-04 16:33:56 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
Michael Levesque-Dion
ebfce197ea Emit dense arrays for StableHLO ops migrating to dense arrays
We are migrating some attrs on some StableHLO ops to use DenseI64ArrayAttr instead of DenseIntElementsAttr. Using DenseI64ArrayAttr enforces that the attr values are 1-dimensional and provides nicer APIs. (see https://github.com/openxla/stablehlo/issues/1578 for additional context)

Unfortunately, we have to duplicate the `dense_int_array` function because we migrated the ops in batches. We can't use the existing `dense_int_array` function because it would produce arrays for ops that hadn't yet been migrated. This PR makes the final batch of changes, so no additional methods should be added going forward.

We also have to introduce a new `dense_bool_array` function, with a similar version check.

When the minimum supported jaxlib version uses a recent enough version of StableHLO  (v6 or above), it will be possible to remove the version checks and remove the duplicated `dense_int_array_v6` function.

PiperOrigin-RevId: 601271749
2024-01-24 16:41:37 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
458a8962be Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
Peter Hawkins
1e961b80da Remove fallback path that lowers all_gather via psum.
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.

I'm planning to improve the XLA:CPU lowering in a subsequent change.

PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
Lukas Geiger
7f5784a903 Add missing f-strings identifiers in xeinsum error message 2023-11-18 03:11:24 +00:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Matthew Johnson
f33ef3ff9c improve psum_scatter docstring (formatting and content) 2023-11-14 17:46:35 -08:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
George Necula
a59ada03bd [export] Adapt several collective lowering rules for multi-platform lowering
This fixes a few more places where the lowering rules used module_context.platform,
which is not supported for multi-platform lowering.
2023-10-13 11:15:41 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Matthew Johnson
6ea8a546f6 always lower all_to_all to AllToAll 2023-04-11 18:31:17 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Jake VanderPlas
671c72a782 Update signature of ad.defbilinear to simplify transpose rules 2023-01-31 09:07:39 -08:00
Jake VanderPlas
dc862a9279 psum: fix docstring formatting 2022-12-27 12:55:48 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Qiao Zhang
9fda20fc29 Add examples to lax.psum to illustrate axis_index_groups better.
PiperOrigin-RevId: 497401892
2022-12-23 12:04:59 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Anselm Levskaya
ffb4711969 Expose channel_id in AllToAllOp in both XLA builder and MHLO.
PiperOrigin-RevId: 494334791
2022-12-09 21:58:28 -08:00
Peter Hawkins
516f0d0d0a Support negative axes in all_gather.
Previously we didn't check for these and they caused crashes during MHLO verification.

PiperOrigin-RevId: 493160581
2022-12-05 17:48:50 -08:00
jax authors
18f77a526b Add inferReturnTypes for PartitionIdOp.
Same as what already exists for ReplicaIdOp.

PiperOrigin-RevId: 491947476
2022-11-30 10:03:38 -08:00
Sholto Douglas
92fd8534cd Fixes b/259636412, all-gather failing when called within xmap in pjit. Piece by piece making xmap in pjit work with all collectives so that we can use it to write 'manual kernels' safely!
PiperOrigin-RevId: 491749906
2022-11-29 15:16:23 -08:00