184 Commits

Author SHA1 Message Date
Oleg Shyshkov
db464b3f0a Clarify documentation for output_offsets operand of ragged_all_to_all.
PiperOrigin-RevId: 708321802
2024-12-20 07:52:11 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Oleg Shyshkov
6d82a6fc90 Allow lax.ragged_all_to_all input and output operands to have different ragged dimension sizes.
We need to guaranty that the outermost dimension of the output is big enough to fit all received elements, but it's not necessary for input and output outermost dimensions to be exactly equal.

PiperOrigin-RevId: 707011916
2024-12-17 02:20:10 -08:00
Parker Schuh
0e7f218eb0 Support axis_index inside shard_map(auto=...) by using iota and
then calling full_to_shard.

PiperOrigin-RevId: 705704369
2024-12-12 18:39:05 -08:00
Gunhyun Park
12c30578b2 Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` is default to all replicas (e.g. there is only one group and covers all axis indices aka iota like the example above).

The current API is inspired from https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all to all.

PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Yash Katariya
9a0e9e55d8 [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in wrap_with_sharding_op.
Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
2024-11-14 17:32:01 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Sergei Lebedev
bdf2ca10fc Removed more dead code from various submodules
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
jax authors
efba5f61b5 Merge pull request #22812 from superbobry:maint
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Abhinav Gunjal
dfe8d94170 Integrate StableHLO at openxla/stablehlo@fb18ee25
PiperOrigin-RevId: 658515936
2024-08-01 13:23:01 -07:00
Sergei Lebedev
92b1f71314 Removed various ununsed functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
jax authors
ede94c3c81 Rollback of https://github.com/google/jax/pull/20705
Causing pmap_test.py failures.

Reverts a7bce471440dda2a8bbeed1fe01dd9f733ef5bbc

PiperOrigin-RevId: 638437174
2024-05-29 15:46:55 -07:00
jax authors
a7bce47144 Merge pull request #20705 from chaserileyroberts:chase/pbroadcast_channel_fix
PiperOrigin-RevId: 637986186
2024-05-28 12:29:40 -07:00
Michael Levesque-Dion
43f51d73ce Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
2024-05-28 10:58:51 -07:00
Chase Roberts
af6970e432 Pipe channel handle 2024-05-28 10:20:50 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Samuel
959ecca182
Minor typo fix in docstring jax.lax.psum
Fix code formatting inconsistency in `psum` docstring

Currently, "device2" and "device3" are rendered incorrectly in the JAX documentation (see second example [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.psum.html))
2024-04-12 13:12:00 +01:00
Chase Roberts
01412f7645 pbroadcast 2024-03-18 15:12:33 -07:00
Philip Pham
3fe65e2005 Pipe tiled through all_to_all primitive
The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled`
argument. Thus, for the transpose to know the right value of `tiled` to pass, we
need to plumb the `tiled` argument through the primitive and various
interpreters, even though it's a no-op because the `tiled` argument is handled
outside the primitive. It would be cleaner to handle `tiled` inside the
primitive, but I will leave that for followup work.

Fixes #15982.

PiperOrigin-RevId: 612628600
2024-03-04 16:33:56 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
Michael Levesque-Dion
ebfce197ea Emit dense arrays for StableHLO ops migrating to dense arrays
We are migrating some attrs on some StableHLO ops to use DenseI64ArrayAttr instead of DenseIntElementsAttr. Using DenseI64ArrayAttr enforces that the attr values are 1-dimensional and provides nicer APIs. (see https://github.com/openxla/stablehlo/issues/1578 for additional context)

Unfortunately, we have to duplicate the `dense_int_array` function because we migrated the ops in batches. We can't use the existing `dense_int_array` function because it would produce arrays for ops that hadn't yet been migrated. This PR makes the final batch of changes, so no additional methods should be added going forward.

We also have to introduce a new `dense_bool_array` function, with a similar version check.

When the minimum supported jaxlib version uses a recent enough version of StableHLO  (v6 or above), it will be possible to remove the version checks and remove the duplicated `dense_int_array_v6` function.

PiperOrigin-RevId: 601271749
2024-01-24 16:41:37 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
458a8962be Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
Peter Hawkins
1e961b80da Remove fallback path that lowers all_gather via psum.
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.

I'm planning to improve the XLA:CPU lowering in a subsequent change.

PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
Lukas Geiger
7f5784a903 Add missing f-strings identifiers in xeinsum error message 2023-11-18 03:11:24 +00:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Matthew Johnson
f33ef3ff9c improve psum_scatter docstring (formatting and content) 2023-11-14 17:46:35 -08:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
George Necula
a59ada03bd [export] Adapt several collective lowering rules for multi-platform lowering
This fixes a few more places where the lowering rules used module_context.platform,
which is not supported for multi-platform lowering.
2023-10-13 11:15:41 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00