514 Commits

Author SHA1 Message Date
jax authors
0b3f0e11fb Reverts ebb75db8a523150c48376d15391f84380a2bb110
PiperOrigin-RevId: 688477769
2024-10-22 03:29:32 -07:00
Yash Katariya
ebb75db8a5 [sharding_in_types] Add out_type argument to einsum and dot_general to allow specifying the output type. Right now it only accepts a NamedSharding, but in the future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.
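
A hedged sketch of the new argument (the flag name, setup, and four-device mesh are assumptions based on this log, not a definitive recipe):

```
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # experimental flag per this log
mesh = jax.make_mesh((2, 2), ('x', 'y'))  # assumes 4 devices

@jax.jit
def f(a, b):
  # Request the output sharding explicitly instead of relying on propagation.
  return jnp.einsum('ij,jk->ik', a, b,
                    out_type=NamedSharding(mesh, P('x', 'y')))
```
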
PiperOrigin-RevId: 688399552
2024-10-21 22:23:53 -07:00
Yash Katariya
2153de4ce0 [sharding_in_types] If out_aval.sharding is not None and the user-specified out_sharding is None, concretize it with the available device assignment and add it to the final out_shardings used for lowering and compilation.
This will allow us to return the exact sharding spec that sharding propagation rules figured out.

PiperOrigin-RevId: 687349015
2024-10-18 10:27:58 -07:00
Yash Katariya
4db212d2c6 Add _sharding argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
This is required because `jax.nn.one_hot` calls into `broadcasted_iota`.
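
For illustration, a small self-contained sketch of why `one_hot` depends on `broadcasted_iota`: the one-hot matrix is an equality comparison against an iota along the class axis.

```
import jax.numpy as jnp
from jax import lax, nn

x = jnp.array([0, 2, 1])
iota = lax.broadcasted_iota(jnp.int32, (3, 4), 1)  # each row is [0, 1, 2, 3]
manual = (x[:, None] == iota).astype(jnp.float32)
assert (manual == nn.one_hot(x, 4)).all()
```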

PiperOrigin-RevId: 687152343
2024-10-17 21:16:51 -07:00
Yash Katariya
3e634d9530 [sharding_in_types] Add lax.transpose sharding propagation rule
PiperOrigin-RevId: 687094297
2024-10-17 17:08:04 -07:00
Yash Katariya
57a95a77ff [sharding_in_types] Support jnp.array with sharding_in_types. When the input array has a sharding, propagate it through without dropping the sharding.
PiperOrigin-RevId: 687089357
2024-10-17 16:51:41 -07:00
Yash Katariya
5df4878ad0 [sharding_in_types] Add reduce max, integer_pow and standard_unop sharding rules
PiperOrigin-RevId: 687073144
2024-10-17 15:55:29 -07:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Vladimir Belitskiy
2f2fd8a334 Skip some Shardy-enabled tests if XLA < 292.
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Yash Katariya
2f6cb89ac0 Add a private property to NamedSharding called _logical_device_ids which allows you to pass a custom tile_assignment_devices() equivalent.
This is because GSPMDSharding doesn't work for Shardy, so `device_put` on a mesh with a different device order needs `NamedSharding` support. A bonus is that the logic is now simplified with respect to the previous version in `_different_device_order_reshard`.

This will also allow us to remove OpSharding usage in other projects which require such kind of permutation capabilities.

PiperOrigin-RevId: 685925636
2024-10-14 20:08:54 -07:00
Tongfei Guo
d621737f13 [XLA:Collective] Expose a factory for constructing HLOSharding with explicit device ordering.
PiperOrigin-RevId: 685858699
2024-10-14 15:41:23 -07:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using Shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Yash Katariya
4be1e332f7 [sharding_in_types] Add constraints during lowering for dot_general and reduce_sum so that we can enforce the sharding we choose during tracing
PiperOrigin-RevId: 685216047
2024-10-12 09:58:53 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
5b8775dc2f [sharding_in_types] Add sharding rule for reduce sum, which just drops the specs for the axes we are reducing over
PiperOrigin-RevId: 685069065
2024-10-11 21:31:25 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example (for an array of shape (8, 2) with dtype float32):

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Yash Katariya
18bc354305 [sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on XLA to insert the collectives.
Cases where we error (a sketch of the non-error case follows the list):

* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* a batch dimension and a tensor dimension have matching shardings
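
A hedged sketch of the non-error case (flag name and setup are assumptions based on this log; assumes 4 devices): the contracting dimensions are unsharded, so the rule can propagate `P('x', 'y')` to the output.

```
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # experimental flag per this log
mesh = jax.make_mesh((2, 2), ('x', 'y'))

lhs = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P('x', None)))
rhs = jax.device_put(jnp.ones((4, 8)), NamedSharding(mesh, P(None, 'y')))

@jax.jit
def f(a, b):
  return a @ b  # sharding rule infers float32[8@x,8@y]; XLA inserts collectives

f(lhs, rhs)
```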

PiperOrigin-RevId: 684983567
2024-10-11 16:05:13 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Yash Katariya
351187d9da [sharding_in_types] Add support for n-ary ops to propagate sharding when one input is sharded and all others are replicated.
PiperOrigin-RevId: 684289345
2024-10-09 21:24:37 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33.
2024-10-04 11:39:05 -04:00
Yash Katariya
79ff8e6232 Cache the iteration over jaxpr equations when extracting shardings because, the majority of the time, it's the same jaxpr, so we don't need to evaluate it again and again.
PiperOrigin-RevId: 682148975
2024-10-03 20:47:59 -07:00
Yash Katariya
1efca33187 Add donate and may_alias as arguments to device_put to allow for donation and aliasing.
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.

**Definition:**

* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.

* may_alias: If True, we may return the original buffer depending on the implementation.

**What problem are we solving?**

Eventually, we want `device_put` to always copy, so we are introducing `may_alias` as a transitional state to help towards that goal. We might end up deciding to keep `may_alias`, but now you have an explicit option to **always copy**, i.e. set `may_alias=False`, which is what some users want.

Adding `donate` allows users to avoid this pattern of code:

```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```

Now it can just be: `jax.device_put(inp, sharding, donate=True)`

**So what are the semantics of these 2 options?** Let's create a table:

| `may_alias` (default None) | `donate` (default False) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted, i.e. donation. Input array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` is treated as False; see the `may_alias = False`, `donate = True` row above |
| None | False | `may_alias` is treated as True; see the `may_alias = True`, `donate = False` row above |

`donate` is best effort for now until we fix the following things (a short usage sketch follows the list):

 * Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.

 * Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.
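
A minimal usage sketch of the two options per the table above (assumes the local device count divides 8):

```
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
s = NamedSharding(mesh, P('x'))

inp = jnp.arange(8.0)
out = jax.device_put(inp, s, donate=True)      # input marked as deleted (best effort)
dup = jax.device_put(out, s, may_alias=False)  # pure copy: never aliases `out`
```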

PiperOrigin-RevId: 681073828
2024-10-01 10:28:23 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
c9bbf71ec6 Clean up ParsedPartitionSpec and remove CanonicalizedParsedPartitionSpec. Also mark user_spec as private.
PiperOrigin-RevId: 676498946
2024-09-19 11:38:48 -07:00
Parker Schuh
86fe463ad7 [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Reverts b615266175effe4aefeb903620a19f3719a604da

PiperOrigin-RevId: 675746175
2024-09-17 16:11:28 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
Bart Chrzaszcz
062a69a97e Make JAX extract the mesh from an AUTO in/out sharding.
Automatic partitioners using JAX+Shardy want to partition models which are fully marked as `AUTO`, i.e. with no in/out sharding given as a `NamedSharding`. In such cases they weren't seeing the mesh on the MLIR module. This makes sure we extract it from the `AUTO` sharding.

PiperOrigin-RevId: 672881018
2024-09-10 03:07:02 -07:00
Yash Katariya
bf66e816dd Split physical axes by default when device kind is TPU v5 lite to allow for mesh shapes (2, 2) when there are 8 v5e devices on a 4x2 topology.
PiperOrigin-RevId: 671047455
2024-09-04 11:49:17 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.
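
A minimal sketch of the new API (axis sizes assume 8 available devices):

```
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((4, 2), ('data', 'model'))
sharding = NamedSharding(mesh, P('data', 'model'))
```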

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Yash Katariya
164b884f33 Fix failing tests in CI
PiperOrigin-RevId: 669357019
2024-08-30 09:49:58 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding-in-types in JAX. Currently we only support n-ary ops with forward-only sharding propagation. This functionality is experimental and hidden behind the jax_sharding_in_types config flag.
There will be more improvements and semantics clarifications coming as we integrate it further into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Yash Katariya
b615266175 Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01
PiperOrigin-RevId: 668969133
2024-08-29 09:43:19 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for Shardy, enable it for pjit_test.py, and fix failing tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was there was no mesh since there weren't any NamedShardings. Fixed by not asserting a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation)
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension is sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so have Shardy assume the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This tests adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device id in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")` which calls `PyMlirModuleToXlaComputation` which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
Tongfei Guo
5da27432e1 [XLA:SPMD] Check that gather/scatter partitioning for the index-parallel case has matching index-parallel dimensions for operand and indices.
PiperOrigin-RevId: 666469705
2024-08-22 13:32:33 -07:00
Yash Katariya
82c9da020a Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
2024-08-20 16:18:58 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint is via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh carrying only mesh axis names and sizes, plus a PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program while we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.

**Why do this?**

There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signature.

So to fix the error, you need to change the mesh in `wsc` and `shmap`, which leads to a tracing cache miss (because the function id is now different) and consequently a lowering-to-StableHLO cache miss. Explaining via an example:

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.

**Okay, so how do you fix this?**

As mentioned above, we need this program to work and to get tracing and lowering cache hits (**cache hits are the most important part** here).

The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.

**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
Yash Katariya
53045380b1 Make custom partitioning work without a mesh context manager. If the arguments have NamedSharding on them, then inside the partition function we should get NamedSharding back without needing the mesh context manager.
PiperOrigin-RevId: 662146686
2024-08-12 10:40:31 -07:00
Yash Katariya
c08656c61d [Rollback] We still want to allow multiple meshes in the user program
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2

PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
jax authors
3bd3597703 Improves error message in case of invalid sharding mesh
PiperOrigin-RevId: 661358450
2024-08-09 12:18:16 -07:00
Yash Katariya
e6303244bf If the memory kind is the default kind throughout the jaxpr, then revert to the previous device_put behavior, which was a no-op inside jit.
This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind.
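
A minimal sketch of the restored behavior, assuming a sharding with the default memory kind:

```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), 'x')
s = NamedSharding(mesh, P('x'))  # default memory kind

@jax.jit
def f(x):
  # Only the default memory kind appears in this jaxpr, so this device_put
  # lowers as a no-op and no mhlo.memory_kind attributes are emitted.
  return jax.device_put(x, s) * 2
```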

PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This allows custom partitioning to return NamedShardings without depending on the mesh context manager, even when the arguments have NamedShardings on them.
Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving in the direction of making Mesh and PartitionSpec first-class sharding types, let's pull the trigger now and start fixing these bad user interactions.

Some things will break due to this change: before, passing a NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit; now it will be a cache miss. In other words: `f(ns); f(ps)  # cache hit before`

In follow-up CLs, we will make the tracing cache aware of the mesh shape too, to fix some other issues related to tracing and lowering cache misses.

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yash Katariya
958234a9c1 Thread the mesh context manager to the place where we recover out_shardings from GSPMDShardings. Before, if you had a program like this:
```
with mesh:
  out = pjit(lambda: 1)()
```

The sharding of `out` was a `GSPMDSharding` which is not ideal. This change fixes that and returns a `NamedSharding` instead.

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
2024-08-02 11:04:48 -07:00
Yash Katariya
e6851e6b22 Fix the AOT check for sharding consistency, which skipped checking the devices of the sharding.
So before, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have errored, which is wrong.

This change fixes that. Also `is_equivalent_to` should check for devices, HloSharding and memory_kind (so removing the redundant `memory_kind` check too).

PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00
Bart Chrzaszcz
4cf1fbe4cc Hide the SDY dialect right before MLIR->HLO conversion in the XLA pipeline.
Since Shardy sits in the middle of the XLA pipeline, after converting down to HLO we need to run the Shardy export pipeline to preserve the SDY ops and sharding attributes for when we come back from HLO to MLIR once Shardy propagation is run.

PiperOrigin-RevId: 658040672
2024-07-31 09:45:43 -07:00
Bart Chrzaszcz
b00f978f70 #sdy Support with_sharding_constraint lowering through Shardy.
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
Yash Katariya
51e27923e8 Simplify pjit's batching rule now that xmap is deleted. Also clean up around adding manual axes under shard_map.
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00