149 Commits

Author SHA1 Message Date
yashkatariya
a65f74b392 Fix the docs build 2023-06-16 13:14:38 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follows:

* If the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings.

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
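The two rules in the commit message above can be read as a small decision function. The following is a minimal sketch of that reading only, not JAX's implementation: `UNSPECIFIED`, `REPLICATED`, and `resolve_none` are hypothetical stand-ins introduced here for illustration.

```python
from enum import Enum


class _Sentinel(Enum):
    # Hypothetical markers; JAX has its own internal representations.
    UNSPECIFIED = "unspecified"
    REPLICATED = "replicated"


UNSPECIFIED = _Sentinel.UNSPECIFIED
REPLICATED = _Sentinel.REPLICATED


def resolve_none(sharding, mesh_is_set):
    """Interpret a user-supplied sharding under the rules described above.

    None means UNSPECIFIED without a mesh context manager, and fully
    replicated (the old semantics) with one. Anything else passes through.
    """
    if sharding is None:
        return REPLICATED if mesh_is_set else UNSPECIFIED
    return sharding
```

Under this sketch, `resolve_none(None, mesh_is_set=False)` yields the unspecified marker, while the same call with `mesh_is_set=True` yields the replicated marker.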
Yash Katariya
79a1bc9a3e No need to return jaxpr from common_infer_params since it is already in params
PiperOrigin-RevId: 539715375
2023-06-12 11:38:06 -07:00
Yash Katariya
fa099fd262 Simplify sharding types input to physical_hlo_sharding and lower_jaxpr_to_fun.
Make sure lower_jaxpr_to_fun always sees HloSharding in arg_shardings and results_shardings.

Also make sure physical_hlo_sharding only accepts HloSharding as the input.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 538342152
2023-06-06 18:03:12 -07:00
Yash Katariya
c7b372a8df Convert OpSharding to HloSharding in the constructor of GSPMDSharding. Also make op_sharding_to_indices work only with HloSharding.
PiperOrigin-RevId: 538165282
2023-06-06 06:35:29 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
5cf0e042eb Change the _most_recent_executable logic to store a weakref dict of jaxpr -> executable so that with the inner cpp cache and outer cpp cache, we extract the correct executable.
PiperOrigin-RevId: 537908874
2023-06-05 10:07:05 -07:00
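The idea of a weakref dict keyed by jaxpr can be illustrated in plain Python. This is a hedged sketch under assumptions, not the code from the commit: `FakeJaxpr` and `ExecutableRegistry` are hypothetical names, and executables are represented by strings.

```python
import weakref


class FakeJaxpr:
    """Hypothetical stand-in for a traced program; not a real JAX class."""


class ExecutableRegistry:
    """Maps each program object to the executable compiled for it.

    Keys are held weakly, so an entry disappears once its program
    object is no longer referenced anywhere else.
    """

    def __init__(self):
        self._by_jaxpr = weakref.WeakKeyDictionary()

    def record(self, jaxpr, executable):
        self._by_jaxpr[jaxpr] = executable

    def lookup(self, jaxpr):
        # Returns None when nothing was recorded for this program.
        return self._by_jaxpr.get(jaxpr)

    def __len__(self):
        return len(self._by_jaxpr)


registry = ExecutableRegistry()
outer, inner = FakeJaxpr(), FakeJaxpr()
registry.record(outer, "outer-executable")
registry.record(inner, "inner-executable")
```

Because the lookup is keyed by the program object rather than by "whichever executable was stored last", an inner and an outer call each retrieve their own executable.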
André Susano Pinto
cfabad5886 Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.
_get_arg_names was throwing IndexError when handling functions with variadic args.

PiperOrigin-RevId: 537308439
2023-06-02 07:43:59 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Yash Katariya
4a5c6f8200 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

Eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing; it was decided that this is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
Yash Katariya
42a8982d23 Smuggle _experimental_lowering_platform via kwargs to keep it hidden, extremely private, and temporary.
PiperOrigin-RevId: 532644979
2023-05-16 19:47:58 -07:00
Yash Katariya
8b9e6bcbd4 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve tracing time.
PiperOrigin-RevId: 532155068
2023-05-15 10:32:24 -07:00
Yash Katariya
b196ad2e8c Remove the f-string evaluation when logging the elapsed time by passing fun_name to log_elapsed_time
PiperOrigin-RevId: 532132574
2023-05-15 09:15:58 -07:00
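The general technique behind the commit above is to pass the format template and its arguments separately, so the message string is only built when it will actually be logged. The following is a minimal sketch of that pattern using the standard `logging` module. It is not JAX's `log_elapsed_time`; the logger name, signature, and template fields here are assumptions made for illustration.

```python
import logging
import time
from contextlib import contextmanager

# Hypothetical logger name used only for this sketch.
logger = logging.getLogger("elapsed_time_sketch")


@contextmanager
def log_elapsed_time(fmt, fun_name):
    """Time the enclosed block and log it at DEBUG level.

    `fmt` is a plain template such as
    "Finished {fun_name} in {elapsed_time} sec". It is formatted only
    if DEBUG logging is enabled, so a caller does not pay for string
    formatting (or for computing the name's string form) otherwise.
    """
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(fmt.format(fun_name=fun_name, elapsed_time=elapsed))
```

A caller would write `with log_elapsed_time("Finished {fun_name} in {elapsed_time} sec", name):` instead of building an f-string at the call site, which would be evaluated eagerly regardless of the log level.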
Yash Katariya
559b837ba5 Add logging if we get a C++ cache miss
PiperOrigin-RevId: 531555996
2023-05-12 11:19:58 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before, if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object if the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely, `NamedSharding({'x': 1, 'y': 2}, P('x'))` is not compatible with an input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via the inline parameter). The majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Jake VanderPlas
8dc06ed2ce Document jax.lax.with_sharding_constraint 2023-04-26 10:19:04 -07:00
jax authors
db2cbd4ae8 Merge pull request #15665 from hawkinsp:sourceinfo
PiperOrigin-RevId: 525581713
2023-04-19 16:30:23 -07:00
Yash Katariya
0a19638490 Plumb debug_info to meshExecutable as an optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
2023-04-19 12:35:49 -07:00
Peter Hawkins
a3b262c379 Use the traceback of the call site when assigning a source location to an inlined function.
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Yash Katariya
75cf3d96d5 Try to preserve shardings with vmap(pjit) by converting the GSPMDShardings to original sharding type via the pxla.py helper
PiperOrigin-RevId: 524966654
2023-04-17 15:32:57 -07:00
Yash Katariya
38c7939fc0 Cache the entire to_gspmd_sharding function to maximize cache hits even for GSPMDShardings
PiperOrigin-RevId: 524951951
2023-04-17 14:33:21 -07:00
George Necula
961b0655fa [shape_poly] Lowering sharding annotations in presence of dynamic shapes
Sharding annotations are lowered to custom calls, and in presence of dynamic shapes
we must use the `indices_of_shape_operands` attribute to hlo.CustomCall.
In order to be able to generate the code to compute the result shapes
we must pass the `LoweringRuleContext` and the result abstract value
to the lowering helpers that generate the custom calls.

The above is easy everywhere, except for the sharding annotations for
the inputs and outputs for a function, because we do not yet have
a LoweringRuleContext available.

This code is tested by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves the support for
dynamic shapes for custom calls: https://github.com/openxla/stablehlo/issues/1367
2023-04-17 14:27:00 +03:00
Yash Katariya
673730c065 Add is_fully_replicated method to Shardings. This allows us to scrub the usage of is_op_sharding_replicated from JAX because we can just query it on Shardings and save an expensive round trip to OpSharding creation.
PiperOrigin-RevId: 524379122
2023-04-14 13:56:33 -07:00
Yash Katariya
c235f214d0 Create same Sharding objects wherever possible to get maximum cache hits
PiperOrigin-RevId: 524116574
2023-04-13 15:22:17 -07:00
Yash Katariya
b06d627c05 Remove _allow_propagation_to_outputs from compile in MeshComputation since after jax.Array it is not required and can just default to being set to True if a sharding is unspecified.
PiperOrigin-RevId: 523851611
2023-04-12 17:38:18 -07:00
Yash Katariya
393e5931d1 Move parse_flatten_op_sharding to sharding_impls.py to remove local import of pjit using that function from pxla.py
PiperOrigin-RevId: 523573375
2023-04-11 19:26:25 -07:00
Yash Katariya
4a5bf290a9 Remove None's from initialization of ParsedPartitionSpec so that we are consistent across jax. This also makes accessing .user_spec return the normalized value.
PiperOrigin-RevId: 523155411
2023-04-10 10:49:06 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Yash Katariya
5d2f453094 Preserve shardings on the output of pjit that were provided on the arguments.
Following are the changes:

* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
  * _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd
  * _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit
  * _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal.

The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.

For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input.

PiperOrigin-RevId: 522991145
2023-04-09 15:42:11 -07:00
jax authors
06569e0889 Merge pull request #15461 from jakevdp:fix-pjit-doc
PiperOrigin-RevId: 522621919
2023-04-07 10:03:55 -07:00
Jake VanderPlas
2af5af1ed9 fix formatting in pjit doc 2023-04-07 09:35:51 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists, so that need is already covered and we are free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Yash Katariya
038ac445c2 Remove global_str since all avals in pjit are global
PiperOrigin-RevId: 522443476
2023-04-06 14:52:07 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Yash Katariya
e42ea83ab8 Improve the error message raised from jax.jit if Pspec or None is passed
PiperOrigin-RevId: 522377813
2023-04-06 10:50:31 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
b926e04afc Remove the shim of functions in sharding_utils from pxla.py and use those functions directly from sharding_utils in JAX
PiperOrigin-RevId: 522319332
2023-04-06 06:18:03 -07:00
Peter Hawkins
29ba2ca926 Report the argument path when encountering an overflow error for a Python value.
PiperOrigin-RevId: 522106244
2023-04-05 11:24:40 -07:00