155 Commits

Author SHA1 Message Date
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
Bart Chrzaszcz
4997e45743 #sdy close any partially sharded dimensions if using auto axes in a shard_map.
PiperOrigin-RevId: 731724837
2025-02-27 07:53:18 -08:00
Yash Katariya
9deb7e3d96 [sharding_in_types] physical_aval should set the correct sharding on ShapedArray so that lowering and compilation don't crash
PiperOrigin-RevId: 730885084
2025-02-25 07:53:14 -08:00
Yash Katariya
66037d10e7 Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
PiperOrigin-RevId: 729532213
2025-02-21 08:05:42 -08:00
Yash Katariya
bcd4048dd5 Set the mesh of tangent.aval when we are creating zeros_like_aval because when you close over an array which is unused, we error out during canonicalization
PiperOrigin-RevId: 729340808
2025-02-20 19:32:34 -08:00
Yash Katariya
b6b319cd06 If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explicit, then treat the axes mentioned in auto as explicit too. In other words, "auto" really means "don't convert to manual", ie leave the listed mesh axes as they are, whether explicit or auto type
PiperOrigin-RevId: 728942780
2025-02-19 21:25:53 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...) is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Parker Schuh
b7c66bd22e Only add new manual axes to residuals when adding axes with partial_auto.
PiperOrigin-RevId: 728839349
2025-02-19 15:27:32 -08:00
Matthew Johnson
3681960427 [shard_map] fix debug_print with partial auto shmap
Co-authored-by: Parker Schuh <parkers@google.com>
2025-02-15 00:23:59 +00:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
Bixia Zheng
9cbff64251 #sdy Enable test_partial_auto_of_random_keys under Shardy.
PiperOrigin-RevId: 720731202
2025-01-28 15:36:52 -08:00
Bart Chrzaszcz
21913b8efa #sdy Enable more shard_map tests under Shardy.
PiperOrigin-RevId: 720193144
2025-01-27 09:10:49 -08:00
Parker Schuh
f3e27b6c28 Support axis_index using a nested shard_map instead of iota with full to shard.
PiperOrigin-RevId: 718661661
2025-01-22 19:14:37 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Zixuan Jiang
64c0f62ec4 Sort manual axes when lowering jax.shard_map to sdy.manual_computation, which ensures the determinism in the generated sdy.manual_computation.
PiperOrigin-RevId: 712973327
2025-01-07 11:02:55 -08:00
Parker Schuh
b49ba6553c Remove the need for check_rep for with_sharding_constraint.
PiperOrigin-RevId: 712630197
2025-01-06 12:59:22 -08:00
Yunlong Liu
3ff000ee3e fix the degenerated case 2025-01-06 16:08:07 +00:00
Yunlong Liu
97b1faacdd Fixes the random key sharding in shard_map. 2024-12-29 18:43:21 +00:00
Matthew Johnson
9f42b99a76 add test for partial-auto ppermute
PiperOrigin-RevId: 707992245
2024-12-19 12:27:18 -08:00
Adam Paszke
3b9a8f7913 Avoid assuming that jnp.sin will be traced in abstract mesh tests
The test does not clear the JAX caches, and jax.sin is a jitted closure
that's shared between all test methods, so there's no guarantee that someone
hasn't already traced sine at that same shape before. This only shows up rarely
since it depends on the subset of tests assigned to the same test executor.

PiperOrigin-RevId: 706706380
2024-12-16 07:45:03 -08:00
Parker Schuh
0e7f218eb0 Support axis_index inside shard_map(auto=...) by using iota and
then calling full_to_shard.

PiperOrigin-RevId: 705704369
2024-12-12 18:39:05 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Keith Rush
c0811c9dff Adds coverage for spmd-axisname-filtering in shard_map transpose.
PiperOrigin-RevId: 699193349
2024-11-22 09:14:29 -08:00
Parker Schuh
2c9b917b9d Don't psum over auto mesh dims in _unmentioned2.
PiperOrigin-RevId: 698440525
2024-11-20 10:36:03 -08:00
Dan Foreman-Mackey
41a0493e56 Add shard map replication rule for ffi_call. 2024-11-14 15:44:31 -08:00
Dan Foreman-Mackey
dfabcb027d Add a shard map replication rule for cond_p. 2024-11-13 06:33:57 -08:00
Bart Chrzaszcz
3544efcade #sdy Fix Shardy bug where we weren't setting shmap in/out shardings as open.
If I revert the change in `shard_map.py`, then the unit test added `test_partial_auto_propagate_through` fails with:
```
self.assertEqual(actual.sharding, sharding)
AssertionError: Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec(), memory_kind=device) != Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec('i',), memory_kind=device)
```
PiperOrigin-RevId: 692971413
2024-11-04 08:12:35 -08:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Parker Schuh
9500bd451a Fix float0 behavior inside shard_map transpose under scan.
PiperOrigin-RevId: 689512880
2024-10-24 14:15:40 -07:00
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
Vladimir Belitskiy
2f2fd8a334 Skip some Shardy-enabled tests if XLA < 292.
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Peter Hawkins
366c823857 Fix test failure when shardy is not enabled.
PiperOrigin-RevId: 679713372
2024-09-27 13:42:20 -07:00
Bart Chrzaszcz
e62a50cd34 #sdy add JAX Shardy support for shard_map.
For example the following JAX program:
```py
devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x'))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```cpp
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Parker Schuh
5e3f7618fc Support pmin and pmax in check_rep.
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
Parker Schuh
1acf9567aa Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
e209abfb2c Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware.
PiperOrigin-RevId: 676571965
2024-09-19 14:49:08 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
Matthew Johnson
358f00d5e0 shmap in_spec None shouldn't require hashability
Co-authored-by: Roy Frostig <frostig@google.com>
2024-09-12 23:03:06 +00:00
Peter Hawkins
5cc5ed2c5c Disable a shard_map test case that fails on TPU v5e.
PiperOrigin-RevId: 672618556
2024-09-09 11:45:41 -07:00
Keith Rush
265bb7bf4c Adds failing test for https://github.com/google/jax/issues/23476.
PiperOrigin-RevId: 672183133
2024-09-07 20:30:18 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
6913551d8d If AbstractMesh is an input to shard_map, then in eager mode require atleast one input to be a NamedSharding not all inputs.
PiperOrigin-RevId: 663310336
2024-08-15 08:10:42 -07:00