155 Commits

Yash Katariya
229cbae5ea Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.
PiperOrigin-RevId: 662938823
2024-08-14 09:03:17 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint is via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today: we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and a PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program while tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit, i.e. in eager mode, if a `shard_map` or `with_sharding_constraint` contains an AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
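
A minimal sketch of that eager-mode rule (assuming at least 2 devices; inferred from the description above rather than taken from the change itself):

```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.lax import with_sharding_constraint

mesh = Mesh(jax.devices()[:2], 'x')
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)

# In eager mode the concrete mesh is recovered from the input's sharding,
# which must have the same axis names and sizes as the abstract mesh.
x = jax.device_put(np.arange(8), NamedSharding(mesh, P('x')))
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
```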

**Why do this?**

There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signature.

So to fix the error, you need to change the mesh in `wsc` and `shmap`, which leads to a tracing cache miss (because the function id is now different) and consequently a StableHLO lowering cache miss. Explaining via an example:

```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.lax import with_sharding_constraint

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.

**Okay, so how do you fix this?**

As mentioned above, we need the above program to work and to get tracing and lowering cache hits (**cache hits are the most important part here**).

The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This causes no errors downstream, and we get tracing and lowering cache hits since we no longer encode the concrete devices, just the axis names and sizes of the mesh.

**The important part is that concrete device information should come only from the arguments. Inside `jax.jit`, you should never reference concrete devices.**

```
# (imports as in the first example)
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Create the abstract mesh from mesh1; since both meshes have the same shape
# (axis names and sizes), either one would do.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.sharding.NamedSharding`, but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
# (imports as in the first example, plus:)
from jax.experimental.shard_map import shard_map

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Create the abstract mesh from mesh1; since both meshes have the same shape
# (axis names and sizes), either one would do.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda v: v, mesh=abstract_mesh,
                in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards-compatible change, so your current code will continue to work as is, but you can opt into the new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
Parker Schuh
d7d9724e14 If the product of manual axes is of size 1, then skip emitting
any shard_to_full or full_to_shard ops.
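
A hedged sketch of the condition (names are illustrative, not the actual lowering internals):

```
import math

def maybe_shard_to_full(x, mesh, manual_axes):
  # If every manual axis has size 1, each shard already equals the full
  # array, so the shard_to_full / full_to_shard ops would be no-ops and
  # can be skipped entirely.
  if math.prod(mesh.shape[a] for a in manual_axes) == 1:
    return x
  return emit_shard_to_full(x, mesh, manual_axes)  # hypothetical emitter
```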

PiperOrigin-RevId: 658116164
2024-07-31 13:12:06 -07:00
Matthew Johnson
abfb1ce72d add temporary flag to suppress an error message, to unblock a user 2024-07-31 17:23:47 +00:00
Parker Schuh
be6b77cc54 Update shard_map(jit) to properly set manual_axes on the in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manual} and this then gets passed to ShardToFull (manual is already considered a full sharding).
PiperOrigin-RevId: 655719254
2024-07-24 15:25:27 -07:00
Matthew Johnson
3beb3d5eec add test for #17691
fixes #17691 (actually fixed by #18854)
2024-07-22 23:23:23 +00:00
jax authors
5ecd1965d1 Merge pull request #22544 from mattjj:19175
PiperOrigin-RevId: 654858318
2024-07-22 12:45:22 -07:00
Matthew Johnson
f7cef92ed7 [shard_map] fix psum rewrite rule's pbroadcast logic
We want to allow `psum(x, axes)` regardless of how `x` is replicated. That
means when we rewrite it into the stricter `psum2`, which can only sum over
non-replicated axes, we need to insert a pbroadcast like this:

```
psum(x, axes) == psum2(pbroadcast(x, axes & input_replicated_axes), axes)
```

In words, we first need to `pbroadcast` over exactly those axes that we're about
to sum over but over which the input is already replicated.

We write it as a comprehension over mesh.axis_names, rather than as that set
intersection, to ensure deterministic ordering, since Python set operations are
not guaranteed to be deterministic. There are other places in the file where we
don't ensure deterministic ordering; someday I'll come back and fix those.
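
A minimal sketch of that rewrite (psum2 and pbroadcast are the internal primitives named above; the helper is illustrative, not the literal source):

```
def rewrite_psum(x, axes, input_replicated_axes, mesh):
  # Comprehending over mesh.axis_names yields a deterministic order,
  # unlike iterating over `axes & input_replicated_axes` directly.
  pbroadcast_axes = tuple(n for n in mesh.axis_names
                          if n in axes and n in input_replicated_axes)
  return psum2(pbroadcast(x, pbroadcast_axes), axes)  # internal primitives
```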

fixes #19175
2024-07-22 17:16:30 +00:00
Matthew Johnson
173794bcef [shard_map] shard_map check_rep=True rules for custom_linear_solve
fixes #21855
2024-07-20 18:06:55 +00:00
jax authors
2985b623bc Merge pull request #22117 from hawkinsp:pyver2
PiperOrigin-RevId: 647079147
2024-06-26 14:29:04 -07:00
Matthew Johnson
df907117da [shard-map] allow in_specs=None to work with arbitrary objects
fixes #22043, #17461
2024-06-26 20:53:41 +00:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Zixuan Jiang
dfd7d17c1d [JAX] Use iota_reshape_dims and iota_transpose_perm in pxla, which is more efficient than tile_assignment_devices.
HloSharding V1 -> HloSharding V2.
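
An illustrative comparison in HLO sharding text format (the device order below is assumed for illustration):

```
# V1 enumerates every device id in the tile assignment:
#   {devices=[2,2]0,2,1,3}
# V2 encodes the same assignment as an iota plus reshape/transpose:
#   {devices=[2,2]<=[2,2]T(1,0)}
# i.e. iota_reshape_dims=(2,2) and iota_transpose_perm=(1,0), avoiding
# materializing the full tile_assignment_devices list.
```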

PiperOrigin-RevId: 639210975
2024-05-31 18:15:00 -07:00
Matthew Johnson
3984d822ba add error checks for vmap spmd_axis_name 2024-05-30 20:48:11 +00:00
Matthew Johnson
0a693faf48 add pjit forwarding rule
Co-authored-by: Roy Frostig <frostig@google.com>
2024-05-25 17:46:01 +00:00
Jaroslav Sevcik
a4f090819f Order axis names for shard_map residuals 2024-05-22 00:27:08 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Parker Schuh
e8ac011d6e Allow nested shard_map.
PiperOrigin-RevId: 632275515
2024-05-09 14:45:40 -07:00
Matthew Johnson
7a87010f84 [shard_map] better fix for spmd_axis_name issues with shmap residuals
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.

We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.

So instead I temporarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.
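
A hedged sketch of the kind of program this affects (assuming 4 devices; the residual saved by `sin` must end up sharded over 'x' but not over the vmapped spmd axis 'y'):

```
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

def f(x):
  return jnp.sin(x)  # sin saves a residual for the backward pass

g = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
h = jax.vmap(g, spmd_axis_name='y')
jax.grad(lambda xs: h(xs).sum())(jnp.ones((2, 8)))
```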

TODO(mattjj): add more testing (maybe in follow-ups)
2024-05-04 01:31:15 +00:00
Parker Schuh
7ba811eb4a Support auto in shard_map.
- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.
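
A hedged sketch of what this enables (assuming 4 devices; `auto` leaves the named axes in it to the compiler while 'x' is manual):

```
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

def f(x):
  return x * 2  # 'x' is manual here; 'y' stays automatic (compiler-managed)

g = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'),
                      auto=frozenset({'y'})))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P('x')))
g(x)
```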

PiperOrigin-RevId: 627156892
2024-04-22 14:29:35 -07:00
Jieying Luo
1b1c6e7c0f Enable some more C API tests.
PiperOrigin-RevId: 627065492
2024-04-22 09:38:59 -07:00
Matthew Johnson
c021117b24 add forwarding optimization test for shard_map 2024-03-15 15:11:16 -07:00
Matthew Johnson
3d32262b21 ignore NamedAxisEffect for remat and dce purposes 2024-03-05 11:04:23 -08:00
Philip Pham
3fe65e2005 Pipe tiled through all_to_all primitive
The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled`
argument. Thus, for the transpose to know the right value of `tiled` to pass, we
need to plumb the `tiled` argument through the primitive and various
interpreters, even though it's a no-op because the `tiled` argument is handled
outside the primitive. It would be cleaner to handle `tiled` inside the
primitive, but I will leave that for followup work.

Fixes #15982.
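
A hedged sketch of the user-facing behavior being plumbed (assuming 2 devices):

```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:2], 'x')

def f(x):
  # tiled=True splits/concatenates along existing axes instead of
  # adding a new leading axis.
  return jax.lax.all_to_all(x, 'x', split_axis=0, concat_axis=0, tiled=True)

g = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
# With `tiled` plumbed through the primitive, the transpose under grad
# sees the right value instead of assuming the default.
jax.grad(lambda x: g(x).sum())(jnp.arange(8.0))
```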

PiperOrigin-RevId: 612628600
2024-03-04 16:33:56 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
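
For reference, the renaming is mechanical:

```
import jax

params = {'w': 1.0, 'b': 2.0}
jax.tree.map(lambda v: v * 2, params)   # new spelling
# jax.tree_map(lambda v: v * 2, params) # deprecated alias, same behavior
```
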
Philip Pham
c93f283f08 [shmap] shard_map and axis_index_groups for standard collectives
These collectives have unreplicated outputs and the exception can simply be
removed. Unit tests are added for `lax.all_gather`, `lax.all_to_all`,
`lax.psum_scatter`.

Fixes #19709.
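
A hedged sketch of the newly allowed pattern (assuming 4 devices; the groups split the 'x' axis into two independent halves):

```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], 'x')

def f(x):
  # all_gather within each group independently; previously this raised
  # an exception under shard_map.
  return jax.lax.all_gather(x, 'x', axis_index_groups=[[0, 1], [2, 3]],
                            tiled=True)

jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x')))(jnp.arange(8.0))
```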

PiperOrigin-RevId: 607104947
2024-02-14 14:23:43 -08:00
Philip Pham
c850b10747 [shmap] Support multiple axes for standard collectives in shard_map 2024-02-06 17:14:03 +00:00
Matthew Johnson
5275ef6e90 [shard-map] fix disable_jit-of-shmap 2024-01-22 20:23:19 -08:00
Matthew Johnson
30c0fc4c5f [shard-map] add approx_top_k replication rule 2024-01-16 03:59:40 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used the mhlo dialect for some time, having moved to StableHLO. Deprecate support for mhlo in JAX's APIs and remove its testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Yash Katariya
dccc0e8e5c Preserve the specs passed by the user in the output sharding from an eager shard_map.
PiperOrigin-RevId: 596665787
2024-01-08 12:09:20 -08:00
Matthew Johnson
12e57dea3f [shard-map] improve error message when a custom_vjp bwd has extra psum 2024-01-02 13:26:40 -08:00
Yash Katariya
57d74d6d24 Always return a NamedSharding from eager shard_map
PiperOrigin-RevId: 592689808
2023-12-20 16:42:57 -08:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
4459991d55 Merge pull request #18961 from mattjj:issue18955
PiperOrigin-RevId: 590655131
2023-12-13 11:04:24 -08:00
jax authors
5104c6b098 Merge pull request #18951 from mattjj:shmap-varargs-error
PiperOrigin-RevId: 590636629
2023-12-13 10:10:18 -08:00
Matthew Johnson
4ba6bd5108 [shard-map] register cumsum et al with generic rules
fixes #18955
2023-12-13 09:54:01 -08:00
Matthew Johnson
2bff2f4094 [shard-map] fix varargs error message bug
see #18823

Co-authored-by: Chase Roberts <chaser@nvidia.com>
2023-12-13 09:40:39 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow-up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
5641d8186e [shard-map] add lax.special rep checking rules 2023-12-06 15:05:36 -08:00
Matthew Johnson
5862852f85 [shard-map] add rewrite and replication checking rules for remat
these rules enable shmap-of-remat with check_rep=True
2023-11-30 10:15:48 -08:00
jax authors
11d7a2b860 Merge pull request #18741 from mattjj:shmap-test-fix
PiperOrigin-RevId: 586710378
2023-11-30 10:09:32 -08:00
Matthew Johnson
5c2635c205 [shard-map] fix test running broken by 0aec40a16fad02f084ef0cabd350db78b86b335e 2023-11-30 09:56:34 -08:00
Matthew Johnson
b8f758e4a0 [shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr 2023-11-29 20:06:12 -08:00
Matthew Johnson
6f20c0af38 [shard-map] add conv replication rules
fixes #18737
2023-11-29 16:58:54 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Matthew Johnson
7589c2bdb8 [shard_map] implement eager custom_jvp / custom_vjp 2023-11-28 16:08:56 -08:00