582 Commits

Author SHA1 Message Date
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is that `device_put` already does a lot right now (and is going to get even more powers in the near future, like cross-host transfers), and its semantics would become very confusing if we kept piling sharding-in-types behavior onto it.
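
The preconditions that separate the two APIs can be modeled in a few lines of plain Python. This is a toy sketch, not JAX code: `ToyMesh`, `can_mesh_cast` and `can_reshard` are hypothetical names, and the requirement that `mesh_cast` keeps axis names and sizes fixed is taken from the earlier `sharding_cast` -> `mesh_cast` rename commit.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a mesh: one name, size and type per axis."""
    axis_names: tuple
    axis_sizes: tuple
    axis_types: tuple  # e.g. ("Hidden", "Visible")


def can_mesh_cast(src: ToyMesh, dst: ToyMesh) -> bool:
    # Same axis names and sizes, but the axis types must differ.
    same_layout = (src.axis_names, src.axis_sizes) == (dst.axis_names, dst.axis_sizes)
    return same_layout and src.axis_types != dst.axis_types


def can_reshard(src: ToyMesh, dst: ToyMesh) -> bool:
    # Names, sizes and types must all match; only the spec may change.
    return src == dst


hidden = ToyMesh(("x", "y"), (2, 2), ("Hidden", "Hidden"))
mixed = ToyMesh(("x", "y"), (2, 2), ("Visible", "Hidden"))
print(can_mesh_cast(hidden, mixed), can_reshard(hidden, mixed))
```

For any pair of meshes at most one of the two checks passes, which matches the split described above: one API changes axis types without visible data movement, the other moves data on an unchanged mesh.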

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter of the hidden_axes and visible_axes functions optional. If `axes` is omitted, you drop into full hidden/visible mode.
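
A minimal pure-Python sketch of what an optional `axes` argument means here, assuming a mesh is modeled as a dict from axis name to axis type. `with_hidden_axes` is a hypothetical helper, not the JAX function, and the string type names are placeholders.

```python
def with_hidden_axes(axis_types, axes=None):
    """Return a copy of `axis_types` with `axes` switched to "Hidden".

    Leaving `axes` as None switches every axis, i.e. full hidden mode.
    """
    if axes is None:
        axes = tuple(axis_types)
    elif isinstance(axes, str):
        axes = (axes,)
    unknown = set(axes) - set(axis_types)
    if unknown:
        raise ValueError(f"unknown mesh axes: {sorted(unknown)}")
    return {name: ("Hidden" if name in axes else kind)
            for name, kind in axis_types.items()}


mesh_types = {"x": "Visible", "y": "Visible"}
print(with_hidden_axes(mesh_types, "x"))  # only x becomes Hidden
print(with_hidden_axes(mesh_types))       # every axis becomes Hidden
```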

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Yash Katariya
695c02b1c4 [sharding_in_types] Rename sharding_cast to mesh_cast and add a few restrictions:
* mesh_cast only works when the axis types between src and dst mesh change. Hence the name!

* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.

* src and dst mesh axis_names and axis_sizes should be the same.

TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
2025-01-17 10:53:43 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into 1 dimension only and all the merging dimensions should be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
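
The cases listed above can be approximated by a small pure-Python model. This is a sketch under assumptions, not the rule as implemented in JAX: `reshape_out_spec` is a hypothetical name, shapes and specs are plain tuples (`None` means unsharded), and any reshape that only adds or drops size-1 dimensions is treated uniformly, which may accept slightly more than the real rule.

```python
import math


def reshape_out_spec(shape, spec, new_shape):
    """Toy model: return the output spec, or raise if the case is ambiguous."""
    shape, spec, new_shape = tuple(shape), tuple(spec), tuple(new_shape)
    if len(shape) != len(spec) or math.prod(shape) != math.prod(new_shape):
        raise ValueError("invalid reshape")
    if shape == new_shape:  # identity reshape
        return spec
    # Adding or removing size-1 dims: the remaining dims keep their specs.
    kept = [(d, s) for d, s in zip(shape, spec) if d != 1]
    if [d for d, _ in kept] == [d for d in new_shape if d != 1]:
        specs = iter(s for _, s in kept)
        return tuple(None if d == 1 else next(specs) for d in new_shape)
    # Trim the dims that match at the front and at the back.
    limit = min(len(shape), len(new_shape))
    pre = 0
    while pre < limit and shape[pre] == new_shape[pre]:
        pre += 1
    suf = 0
    while suf < limit - pre and shape[-1 - suf] == new_shape[-1 - suf]:
        suf += 1
    mid_in = spec[pre:len(spec) - suf]
    n_out = len(new_shape) - pre - suf
    # One unsharded dim split into several, or several unsharded dims merged.
    if all(s is None for s in mid_in) and (len(mid_in) == 1 or n_out == 1):
        return spec[:pre] + (None,) * n_out + spec[len(spec) - suf:]
    raise ValueError("ambiguous reshape; please pass out_sharding")


print(reshape_out_spec((4, 6, 8), ("x", "y", None), (4, 6, 2, 2, 2)))
print(reshape_out_spec((4, 2, 3, 8), ("y", None, None, None), (4, 6, 8)))
```

In this model a reshape such as (4, 6, 8) -> (4, 4, 2, 6) raises, because the dimensions that change are neither a single split nor a single merge.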

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in its pytype_aval_mapping.
Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Bart Chrzaszcz
c14e5b4332 Add JAX unit test for Shardy which causes the compiler to introduce the mlir::tensor::TensorDialect. This was causing the compiler to crash.
PiperOrigin-RevId: 714896947
2025-01-13 03:08:33 -08:00
Yash Katariya
6b253b2f75 [shardy] Fix cases in shardy where you have a nullary function with partially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings).
In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's).

PiperOrigin-RevId: 714681941
2025-01-12 08:09:17 -08:00
Yash Katariya
a817f532b4 [sharding_in_types] Introduce auto_mode, user_mode, auto_mode_ctx and user_mode_ctx as **private** APIs to make writing auto/user sharding in types code way easier and noise-free.
These can be made public in the future under different names.

PiperOrigin-RevId: 714169304
2025-01-10 14:14:25 -08:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Yash Katariya
6319126e2d Remove extraneous print statement in a test
PiperOrigin-RevId: 713830757
2025-01-09 16:20:21 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in  pjit_test.py for Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contains mesh axis names that are of type Auto or Collective for the above functions.
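
The conversion described above can be sketched without JAX. Everything in this block is a hypothetical stand-in: `toy_mesh_context` plays the role of the mesh context, a spec is a plain tuple, `ToyNamedSharding` is not `NamedSharding`, and the axis-type strings are placeholders.

```python
import contextlib
import contextvars
from dataclasses import dataclass

_mesh_ctx = contextvars.ContextVar("toy_mesh", default=None)


@contextlib.contextmanager
def toy_mesh_context(axis_types):
    """Install a toy mesh (dict of axis name -> axis type) for the block."""
    token = _mesh_ctx.set(dict(axis_types))
    try:
        yield
    finally:
        _mesh_ctx.reset(token)


@dataclass(frozen=True)
class ToyNamedSharding:
    mesh: tuple  # (axis name, axis type) pairs
    spec: tuple


def _axis_names(entry):
    if entry is None:
        return ()
    return (entry,) if isinstance(entry, str) else tuple(entry)


def canonicalize_out_sharding(out_sharding):
    """Turn a bare spec into a ToyNamedSharding before it reaches a primitive."""
    if isinstance(out_sharding, ToyNamedSharding):
        return out_sharding
    mesh = _mesh_ctx.get()
    if mesh is None:
        raise ValueError("a bare spec needs a mesh in context")
    for entry in out_sharding:
        for name in _axis_names(entry):
            if mesh[name] in ("Auto", "Collective"):
                raise ValueError(f"axis {name!r} has type {mesh[name]}")
    return ToyNamedSharding(tuple(mesh.items()), tuple(out_sharding))


with toy_mesh_context({"x": "User", "y": "Auto"}):
    print(canonicalize_out_sharding(("x", None)))
```

Doing the conversion in the wrapper keeps the primitive's parameters uniform, which is the design the commit message describes with "above `.bind`".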

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Yash Katariya
755d6cdad8 [sharding_in_types] Aval sharding under full auto mode should contain None and not UNCONSTRAINED because axis_types + pspec give the full picture.
PiperOrigin-RevId: 713105375
2025-01-07 18:04:20 -08:00
Bixia Zheng
16712b5116 Ensure that the two offsets of a dynamic_slice have the same dtype regardless
of the value of config.enable_x64.

PiperOrigin-RevId: 708031525
2024-12-19 14:21:49 -08:00
Yash Katariya
9041b02dff Account for tokens in allow_spmd_sharding_propagation_to_parameters and allow_spmd_sharding_propagation_to_output compile options
PiperOrigin-RevId: 707723232
2024-12-18 17:51:33 -08:00
Yash Katariya
af63e443ef [sharding_in_types] Check out_avals with mesh context too. This is because users can pass their own shardings to functions like einsum, reshape, broadcast, etc.
PiperOrigin-RevId: 707672801
2024-12-18 14:42:40 -08:00
Yash Katariya
09fdd0daaa [sharding_in_types] Add tests allowing inputs and outputs of jit to have different axis_types on their mesh than the axis_types on the surrounding mesh context
PiperOrigin-RevId: 707356052
2024-12-17 19:45:46 -08:00
Yash Katariya
e854f1657a Allow P.UNCONSTRAINED in out_shardings at top level jit. This is required for sharding in types to work properly when out_avals contain UNCONSTRAINED specs.
This also simplifies the `impl` rule of `sharding_cast`.

PiperOrigin-RevId: 707349491
2024-12-17 19:18:24 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Adam Paszke
3b9a8f7913 Avoid assuming that jnp.sin will be traced in abstract mesh tests
The test does not clear the JAX caches, and jnp.sin is a jitted closure
that's shared between all test methods, so there's no guarantee that someone
hasn't already traced sine at that same shape before. This only shows up rarely
since it depends on the subset of tests assigned to the same test executor.

PiperOrigin-RevId: 706706380
2024-12-16 07:45:03 -08:00
Yash Katariya
d0f63da4b5 Allow tracing and lowering (with lowering_platforms specified) to work with an AbstractMesh. Such a computation cannot be compiled.
This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.

PiperOrigin-RevId: 705764178
2024-12-12 23:17:27 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
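
The callback pattern described above can be sketched in plain Python. The names here are hypothetical (`emit_test_event`, `register_listener`, `EventCounter`), not the actual helpers in JAX's `util` or `test_utils` modules.

```python
import threading

_listener = None  # set only while a test wants to observe events


def register_listener(fn):
    """Install (or clear, with None) the callback that receives events."""
    global _listener
    _listener = fn


def emit_test_event(name, *args):
    """Report a point of interest; does nothing when no listener is set."""
    listener = _listener
    if listener is not None:
        listener(name, *args)


class EventCounter:
    """Counts events by name; a lock keeps the counts consistent across threads."""

    def __init__(self):
        self._lock = threading.Lock()
        self.counts = {}

    def __call__(self, name, *args):
        with self._lock:
            self.counts[name] = self.counts.get(name, 0) + 1


counter = EventCounter()
register_listener(counter)
emit_test_event("lower")
emit_test_event("lower")
register_listener(None)
print(counter.counts)
```

Library code only ever calls the emit function, so tests can count events without replacing any library function at runtime.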
2024-12-12 09:58:14 -05:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Yash Katariya
41f490aef4 [sharding_in_types] Default axis_types to Auto for all axis_names if user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.
PiperOrigin-RevId: 704944255
2024-12-10 20:20:23 -08:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further, which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs
2024-12-10 11:11:32 -08:00
Kanglan Tang
66b900540a Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e
PiperOrigin-RevId: 704378732
2024-12-09 12:30:36 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
Yash Katariya
a735bf83e5 Simplify abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Yash Katariya
9e2708eb57 [sharding_in_types] Use set_mesh API to trigger sharding_in_types instead of the config option.
PiperOrigin-RevId: 702814257
2024-12-04 12:12:29 -08:00
Yash Katariya
456dfeb0ae [Take 2] Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh
Reverts a158e02b7d1c1a50e53adfec7f48bec69cc0dc5b

PiperOrigin-RevId: 701045239
2024-11-28 09:24:20 -08:00
Fabian Mentzer
a158e02b7d Reverts cc5036cc18bc585b0d92a4f606956da084effbad
PiperOrigin-RevId: 700998046
2024-11-28 05:35:38 -08:00
Yash Katariya
cc5036cc18 Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh
PiperOrigin-RevId: 700779838
2024-11-27 12:43:08 -08:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
Yash Katariya
59e13f8114 Add sharding argument to reshape since it also takes a shape argument for the output shape
PiperOrigin-RevId: 700163883
2024-11-25 18:16:08 -08:00
Yash Katariya
deab6fbd80 Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too.
If you drop out of cpp cache, things are going to be slow anyways.

PiperOrigin-RevId: 700052522
2024-11-25 11:40:50 -08:00
Bill Varcho
066859e62f [SDY] Enable test_pjit_array_multi_input_multi_output since Shardy conflict resolution is now complete.
PiperOrigin-RevId: 700042542
2024-11-25 11:10:00 -08:00