Yash Katariya
88d4bc3d45
Rename AxisTypes enum to AxisType
...
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Yash Katariya
a4ca0dbc6c
Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types)
instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types)
so that we are consistent across all Mesh APIs: Mesh
, AbstractMesh
and make_mesh
...
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759
[sharding_in_types] Rework the axis_types
argument in Mesh and AbstractMesh APIs. The changes are:
...
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Matthew Johnson
7c2f842353
shard_map and other fixes to direct-linearize
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Yash Katariya
53494ade2d
PRNGKeyArray.aval
should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
...
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
Bart Chrzaszcz
4997e45743
#sdy close any partially sharded dimensions if using auto
axes in a shard_map
.
...
PiperOrigin-RevId: 731724837
2025-02-27 07:53:18 -08:00
Yash Katariya
9deb7e3d96
[sharding_in_types] physical_aval
should set the correct sharding on ShapedArray
so that lowering and compilation don't crash
...
PiperOrigin-RevId: 730885084
2025-02-25 07:53:14 -08:00
Yash Katariya
66037d10e7
Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
...
PiperOrigin-RevId: 729532213
2025-02-21 08:05:42 -08:00
Yash Katariya
bcd4048dd5
Set the mesh of tangent.aval
when we are creating zeros_like_aval
because when you close over an array which is unused, we error out during canonicalization
...
PiperOrigin-RevId: 729340808
2025-02-20 19:32:34 -08:00
Yash Katariya
b6b319cd06
If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explicit, then treat the axes mentioned in auto
as explicit too. In other words, "auto" really means "don't convert to manual", ie leave the listed mesh axes as they are, whether explicit or auto type
...
PiperOrigin-RevId: 728942780
2025-02-19 21:25:53 -08:00
Yash Katariya
8305803b76
[sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...)
is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
...
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Parker Schuh
b7c66bd22e
Only add new manual axes to residuals when adding axes with partial_auto.
...
PiperOrigin-RevId: 728839349
2025-02-19 15:27:32 -08:00
Matthew Johnson
3681960427
[shard_map] fix debug_print with partial auto shmap
...
Co-authored-by: Parker Schuh <parkers@google.com>
2025-02-15 00:23:59 +00:00
George Necula
550d1aa187
[better_errors] Continue adding debug info to Jaxprs (step 6)
...
This follows in a series, starting with #26078 and #26313 , adding debug_info to more calls to lu.wrap_init.
Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
Bixia Zheng
9cbff64251
#sdy Enable test_partial_auto_of_random_keys under Shardy.
...
PiperOrigin-RevId: 720731202
2025-01-28 15:36:52 -08:00
Bart Chrzaszcz
21913b8efa
#sdy Enable more shard_map
tests under Shardy.
...
PiperOrigin-RevId: 720193144
2025-01-27 09:10:49 -08:00
Parker Schuh
f3e27b6c28
Support axis_index using a nested shard_map instead of iota with full to shard.
...
PiperOrigin-RevId: 718661661
2025-01-22 19:14:37 -08:00
Peter Hawkins
efab6945ca
Remove code that supported jaxlib < 0.5.
...
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Peter Hawkins
8f2f4b45fb
Annotate several tests as thread-unsafe.
...
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
George Necula
dd0447a7c6
[aot] Add support for as_text(debug_info=True).
...
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
51b9fe3010
[JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
...
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Zixuan Jiang
64c0f62ec4
Sort manual axes when lowering jax.shard_map
to sdy.manual_computation
, which ensures the determinism in the generated sdy.manual_computation
.
...
PiperOrigin-RevId: 712973327
2025-01-07 11:02:55 -08:00
Parker Schuh
b49ba6553c
Remove the need for check_rep for with_sharding_constraint.
...
PiperOrigin-RevId: 712630197
2025-01-06 12:59:22 -08:00
Yunlong Liu
3ff000ee3e
fix the degenerated case
2025-01-06 16:08:07 +00:00
Yunlong Liu
97b1faacdd
Fixes the random key sharding in shard_map.
2024-12-29 18:43:21 +00:00
Matthew Johnson
9f42b99a76
add test for partial-auto ppermute
...
PiperOrigin-RevId: 707992245
2024-12-19 12:27:18 -08:00
Adam Paszke
3b9a8f7913
Avoid assuming that jnp.sin will be traced in abstract mesh tests
...
The test does not clear the JAX caches, and jax.sin is a jitted closure
that's shared between all test methods, so there's no guarantee that someone
hasn't already traced sine at that same shape before. This only shows up rarely
since it depends on the subset of tests assigned to the same test executor.
PiperOrigin-RevId: 706706380
2024-12-16 07:45:03 -08:00
Parker Schuh
0e7f218eb0
Support axis_index inside shard_map(auto=...) by using iota and
...
then calling full_to_shard.
PiperOrigin-RevId: 705704369
2024-12-12 18:39:05 -08:00
Peter Hawkins
62e66b684b
Don't monkey-patch functions in test_utils to count events for tests.
...
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Keith Rush
c0811c9dff
Adds coverage for spmd-axisname-filtering in shard_map transpose.
...
PiperOrigin-RevId: 699193349
2024-11-22 09:14:29 -08:00
Parker Schuh
2c9b917b9d
Don't psum over auto mesh dims in _unmentioned2.
...
PiperOrigin-RevId: 698440525
2024-11-20 10:36:03 -08:00
Dan Foreman-Mackey
41a0493e56
Add shard map replication rule for ffi_call.
2024-11-14 15:44:31 -08:00
Dan Foreman-Mackey
dfabcb027d
Add a shard map replication rule for cond_p.
2024-11-13 06:33:57 -08:00
Bart Chrzaszcz
3544efcade
#sdy Fix Shardy bug where we weren't setting shmap in/out shardings as open.
...
If I revert the change in `shard_map.py`, then the unit test added `test_partial_auto_propagate_through` fails with:
```
self.assertEqual(actual.sharding, sharding)
AssertionError: Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec(), memory_kind=device) != Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec('i',), memory_kind=device)
```
PiperOrigin-RevId: 692971413
2024-11-04 08:12:35 -08:00
Bart Chrzaszcz
44158ab0e4
#sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
...
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike
Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.
Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.
PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Parker Schuh
9500bd451a
Fix float0 behavior inside shard_map transpose under scan.
...
PiperOrigin-RevId: 689512880
2024-10-24 14:15:40 -07:00
Yash Katariya
6c8e56f43f
Finish 0.4.35 release by removing dead code
...
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
Vladimir Belitskiy
2f2fd8a334
Skip some Shardy-enabled tests if XLA < 292.
...
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Peter Hawkins
366c823857
Fix test failure when shardy is not enabled.
...
PiperOrigin-RevId: 679713372
2024-09-27 13:42:20 -07:00
Bart Chrzaszcz
e62a50cd34
#sdy add JAX Shardy support for shard_map.
...
For example the following JAX program:
```py
devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x'))
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)))
@jax.jit
@partial(
shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
axis_size = lax.psum(1, 'x')
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
return lax.ppermute(a, 'x', perm=perm)
print(jax.jit(fwd).lower(a).as_text())
```
prints:
```cpp
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <["x"=8]>
func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
return %0 : tensor<8x8xi32>
}
func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
%1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
sdy.return %1 : tensor<1x8xi32>
} : (tensor<8x8xi32>) -> tensor<8x8xi32>
return %0 : tensor<8x8xi32>
}
}
```
PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Parker Schuh
5e3f7618fc
Support pmin and pmax in check_rep.
...
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
Parker Schuh
1acf9567aa
Add get_replication to shard_map.py for verifying if an array is replicated.
...
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
e209abfb2c
Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware.
...
PiperOrigin-RevId: 676571965
2024-09-19 14:49:08 -07:00
Peter Hawkins
940860625e
Remove code that existed to support jaxlib < 0.4.32.
...
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57
PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
Matthew Johnson
358f00d5e0
shmap in_spec None shouldn't require hashability
...
Co-authored-by: Roy Frostig <frostig@google.com>
2024-09-12 23:03:06 +00:00
Peter Hawkins
5cc5ed2c5c
Disable a shard_map test case that fails on TPU v5e.
...
PiperOrigin-RevId: 672618556
2024-09-09 11:45:41 -07:00
Keith Rush
265bb7bf4c
Adds failing test for https://github.com/google/jax/issues/23476 .
...
PiperOrigin-RevId: 672183133
2024-09-07 20:30:18 -07:00
Yash Katariya
e1b497078e
Rename jtu.create_global_mesh
to jtu.create_mesh
and use jax.make_mesh
inside jtu.create_mesh
to get maximum test coverage of the new API.
...
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
6913551d8d
If AbstractMesh
is an input to shard_map
, then in eager mode require atleast one input to be a NamedSharding
not all inputs.
...
PiperOrigin-RevId: 663310336
2024-08-15 08:10:42 -07:00