332 Commits

Author SHA1 Message Date
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
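The buffer protocol support described in e464dc8700 can be exercised roughly as follows. This is a minimal sketch, assuming a CPU device is available and that the protocol is only offered for CPU-backed arrays; the variable names are illustrative and not taken from the commit.

```python
import numpy as np
import jax

# Place the data on a CPU device explicitly. Assumption: the buffer
# protocol is only available for arrays that live in host memory.
cpu = jax.devices("cpu")[0]
x = jax.device_put(np.arange(4, dtype=np.int32), cpu)

# memoryview() goes through the Python buffer protocol rather than
# through __array__ or an explicit copy.
view = memoryview(x)

# NumPy can wrap the exported buffer directly.
as_numpy = np.asarray(view)
```

Holding `view` keeps an external reference to the underlying buffer, which is the situation the reland's lifetime fix is concerned with.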
Ce Zheng
8e397f7f08 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors.
PiperOrigin-RevId: 535079635
2023-05-24 20:26:28 -07:00
Ce Zheng
4f1f5e4516 [XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding.
PiperOrigin-RevId: 534600886
2023-05-23 16:37:42 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Sholto Douglas
e0b5003880 Allow unconstrained dimensions when using NamedShardings.
PiperOrigin-RevId: 533752415
2023-05-20 16:28:12 -07:00
Parker Schuh
56ca8af9bb Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
2023-05-19 16:58:54 -07:00
Yash Katariya
4a5c6f8200 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Parker Schuh
08169291a4 Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate
arguments for shape and sharding.

PiperOrigin-RevId: 533257532
2023-05-18 14:48:07 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

Eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing; it was decided that this is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
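The eager behaviour added in f1c2711292 means one function body can be called both with and without `jax.jit`. The sketch below assumes a one-axis mesh over the local devices and an uncommitted input; the mesh axis name `"x"` and the function `scale` are illustrative choices, not part of the commit.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over every local device (a single device on a plain CPU host).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

def scale(a):
    # The same constraint is applied whether or not the caller used jit.
    a = jax.lax.with_sharding_constraint(a, sharding)
    return a * 2

# Length is a multiple of the device count so the axis divides evenly.
x = jnp.arange(8 * len(devices), dtype=jnp.float32)

eager_out = scale(x)            # goes through the eager impl rule
jitted_out = jax.jit(scale)(x)  # goes through the traced path
```

For moving data between shardings outside of a jit, the commit message points to `jax.device_put` rather than eager wsc.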
Yash Katariya
8b9e6bcbd4 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve tracing time.
PiperOrigin-RevId: 532155068
2023-05-15 10:32:24 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before, if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding, which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding.

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Rahul Joshi
9d750ae97d Fix pjit outfeed test to avoid potential deadlocks.
PiperOrigin-RevId: 529076350
2023-05-03 06:51:26 -07:00
Yash Katariya
40349a8612 Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
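The normalization mentioned in 40349a8612 can be pictured with a small pure-Python sketch. The helper name `array_mapping_to_spec` and the plain-tuple return value are hypothetical stand-ins chosen for illustration; they are not the JAX internals.

```python
from collections import OrderedDict

def array_mapping_to_spec(array_mapping, ndim):
    """Hypothetical helper: invert {axis_name: dim_index} into a per-dimension spec."""
    per_dim = [[] for _ in range(ndim)]
    for axis_name, dim in array_mapping.items():
        per_dim[dim].append(axis_name)

    spec = []
    for axes in per_dim:
        if not axes:
            spec.append(None)         # dimension is not partitioned
        elif len(axes) == 1:
            spec.append(axes[0])      # normalize ('x',) to 'x'
        else:
            spec.append(tuple(axes))  # several mesh axes on one dimension
    return tuple(spec)

single = array_mapping_to_spec(OrderedDict([("x", 0)]), ndim=2)
multi = array_mapping_to_spec(OrderedDict([("x", 0), ("y", 0)]), ndim=1)
```

Without the single-axis branch, `single` would carry `('x',)` in its first slot and compare unequal to a spec written with the bare string.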
Yash Katariya
c52e48b6c0 Only return the same input Sharding object if the original aval's ndim and out_aval's ndim are the same.
This is because if both OpShardings are replicated, the ndim is not encoded in the OpSharding, so the comparison returns True even if the Sharding is incompatible with the output's ndim. Concretely, `NamedSharding({'x': 1, 'y': 2}, P('x'))` is not compatible with an input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
Yash Katariya
86c1f5bcee Preserve the sharding type of physical sharding on logical sharding when .sharding is accessed on a PRNGKeyArray
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Jake VanderPlas
3108f05eee fix pjit_test:testWithCustomPRNGKey
2023-04-25 10:52:15 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Parker Schuh
87c328864b Improve testing for custom_partitioning.
Add a test to demonstrate how to force XLA to choose
a different sharding.

Also it is possible to return the wrong
shape from a partition function. We should error in this case.

PiperOrigin-RevId: 525606690
2023-04-19 18:26:51 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
Yash Katariya
0a19638490 Plumb debug_info to MeshExecutable as an optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
2023-04-19 12:35:49 -07:00
Yash Katariya
75cf3d96d5 Try to preserve shardings with vmap(pjit) by converting the GSPMDShardings to original sharding type via the pxla.py helper
PiperOrigin-RevId: 524966654
2023-04-17 15:32:57 -07:00
Yash Katariya
c235f214d0 Create same Sharding objects wherever possible to get maximum cache hits
PiperOrigin-RevId: 524116574
2023-04-13 15:22:17 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
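The corner case behind 5521423d92 is easy to see on an empty shape. The sketch below uses only the standard library; the helper name `num_elements` is illustrative and not taken from the commit.

```python
import math

def num_elements(shape):
    """Number of elements implied by a static shape tuple."""
    # math.prod keeps integer inputs as Python ints, including for the
    # empty product, whereas np.prod(()) yields a floating-point 1.0.
    return math.prod(shape)

scalar_size = num_elements(())         # empty shape, i.e. a scalar
matrix_size = num_elements((2, 3, 4))
```

Keeping the result an `int` matters when the value is later used as a dimension or an index.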
Peter Hawkins
2e524411db Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.

PiperOrigin-RevId: 524013922
2023-04-13 08:55:44 -07:00
Yash Katariya
393e5931d1 Move parse_flatten_op_sharding to sharding_impls.py to remove local import of pjit using that function from pxla.py
PiperOrigin-RevId: 523573375
2023-04-11 19:26:25 -07:00
Yash Katariya
cf2f182a6c Preserve PositionalSharding on the output of pjits if the inputs had PositionalSharding on them by converting GSPMDSharding to PositionalSharding
PiperOrigin-RevId: 523535581
2023-04-11 16:28:03 -07:00
Yash Katariya
5d1abe1ba9 Make apply_primitive preserve shardings on outputs.
PiperOrigin-RevId: 523186148
2023-04-10 12:41:02 -07:00
Yash Katariya
49438c78e4 Do the sharding.addressable_devices check only once in _get_input_indices since all shardings should have the same device_assignment.
That check happens at the start of lower_sharding_computation. Also use the optimized DeviceAssignment object which has all the calculations cached if this path is hit multiple times.

Also remove `device_assignment` from MeshExecutable since it is not used anywhere in that class

PiperOrigin-RevId: 523182028
2023-04-10 12:23:25 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Yash Katariya
a1797170af Create a proper NamedSharding without None as the pspec. This happens when users pass None as the out_shardings/in_shardings and pjit should convert it to a proper PartitionSpec.
PiperOrigin-RevId: 523125287
2023-04-10 08:43:01 -07:00
Yash Katariya
5d2f453094 Preserve shardings on the output of pjit that were provided on the arguments.
Following are the changes:

* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
  * _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd
  * _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit
  * _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal.

The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.

For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input.

PiperOrigin-RevId: 522991145
2023-04-09 15:42:11 -07:00
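The split described in 5d2f453094 can be illustrated with a toy model: compilation is keyed on what a sharding lowers to, while the output handling is rebuilt from the exact object the caller passed. Everything below (`ToySharding`, `cached_compilation`, `run`) is a hypothetical stand-in written for this sketch, not the JAX code.

```python
from dataclasses import dataclass
from functools import lru_cache

@dataclass(frozen=True)
class ToySharding:
    """Hypothetical stand-in for a sharding object."""
    kind: str         # user-facing type, e.g. "named" or "gspmd"
    op_sharding: str  # what the sharding lowers to

compile_count = 0

@lru_cache(maxsize=None)
def cached_compilation(fn_name, op_sharding):
    """Keyed on the lowered form only, so semantically equal shardings share an entry."""
    global compile_count
    compile_count += 1
    return f"executable({fn_name}, {op_sharding})"

def run(fn_name, sharding):
    executable = cached_compilation(fn_name, sharding.op_sharding)
    # The output handler is rebuilt from the exact sharding on every call,
    # so the caller gets back the sharding type it passed in.
    return executable, sharding.kind

exe_a, kind_a = run("f", ToySharding("named", "{devices=[2]0,1}"))
exe_b, kind_b = run("f", ToySharding("gspmd", "{devices=[2]0,1}"))
```

Both calls share one compiled artifact because the lowered form is identical, yet each result still reports the sharding kind that was supplied.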
Yash Katariya
038ac445c2 Remove global_str since all avals in pjit are global
PiperOrigin-RevId: 522443476
2023-04-06 14:52:07 -07:00
Yash Katariya
e42ea83ab8 Improve the error message raised from jax.jit if Pspec or None is passed
PiperOrigin-RevId: 522377813
2023-04-06 10:50:31 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
b926e04afc Remove the shim of functions in sharding_utils from pxla.py and use those functions directly from sharding_utils in JAX
PiperOrigin-RevId: 522319332
2023-04-06 06:18:03 -07:00
Yash Katariya
728a5ed96a [shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
Yash Katariya
78678ee9e1 Rename count_pjit_cache_miss to count_pjit_cpp_cache_miss because it is confusing which cache the first function is talking about, as pjit has many caches
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Skye Wanderman-Milne
4cb3b011a0 Remove PJRT C API bypass.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Yash Katariya
c58e2f6280 Improve the empty mesh error message raised in pjit if mesh is not used and Pspec is passed to in|out_shardings
PiperOrigin-RevId: 517495400
2023-03-17 13:37:06 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00