16 Commits

Author SHA1 Message Date
Yash Katariya
250e2ee7da Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output is correct.
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering, otherwise things go horribly wrong. This is actually a XLA bug but we can workaround it in JAX for now.

PiperOrigin-RevId: 729307115
2025-02-20 17:13:24 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can atleast upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Jake VanderPlas
aa6adfb9c5 Remove unused import 2024-11-08 13:36:28 -08:00
Bill Varcho
afd8239ea4 [SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
PiperOrigin-RevId: 694567084
2024-11-08 11:02:50 -08:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jieying Luo
4c57d09590 [PJRT C API] Bump the minimum support libtpu version because there is a breaking change (075d25e0c1).
Also remove skip condition that are no longer needed because of this bump.

PiperOrigin-RevId: 611288492
2024-02-28 17:51:24 -08:00
Jieying Luo
3dbbfefef8 [PJRT C API] Add a helper method to check whether the backend is cloud TPU built after certain date.
Skip tests that are not intended to work with older version libtpu.

PiperOrigin-RevId: 610892754
2024-02-27 15:25:10 -08:00
Yash Katariya
6f96c963ff Preserve single device NamedSharding/PositionalSharding on the output instead of always return SingleDeviceShardings.
Fixes https://github.com/google/jax/issues/19459

PiperOrigin-RevId: 600999853
2024-01-23 21:29:14 -08:00
Jake VanderPlas
80aa128e88 Guard shard_alike usage on xla_extension_version 2024-01-19 13:02:29 -08:00
Yash Katariya
68abe0da5c Add correct batching rule for shard_alike
PiperOrigin-RevId: 595532031
2024-01-03 16:24:27 -08:00
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is x, y = shard_like(x, y).
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00