173 Commits

Author SHA1 Message Date
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Peter Hawkins
3ce5cb6d22 Remove use of get_default_device_assignment().
This is the only caller of this API in JAX, and it can be simplified.

Change in preparation for removing get_default_device_assignment() from the Python bindings.

PiperOrigin-RevId: 563770199
2023-09-08 09:18:32 -07:00
Yash Katariya
65a2ae9dec Remove always_lower now that trivial computations don't exist
PiperOrigin-RevId: 561516209
2023-08-30 19:31:30 -07:00
Yash Katariya
603c879fa0 Run _check_sharding checks during api.device_put instead of in the impl rule so that we don't have to repeat these checks in each rule of device_put.
The same is done for jit and with_sharding_constraint.

PiperOrigin-RevId: 561380348
2023-08-30 10:27:37 -07:00
Yash Katariya
e785f89470 Raise a good error message when mesh is not provided to jax.jit when using spmd_axis_name parameter of jax.vmap
PiperOrigin-RevId: 561217612
2023-08-29 20:58:57 -07:00
Yash Katariya
0501a15fd5 Print str_short of the arg and remove printing the value of the arg.
PiperOrigin-RevId: 559524941
2023-08-23 13:31:35 -07:00
Parker Schuh
e58ddb7258 Add _manual_axes support to NamedSharding. This is needed because
custom_partitioning may produce manually sharded axes.

PiperOrigin-RevId: 559288864
2023-08-22 19:24:29 -07:00
Jake VanderPlas
9aca944891 Fix type annotation for tree_util.default_registry 2023-08-16 15:07:48 -07:00
Yash Katariya
1ae37b4131 Canonicalize to default memory in init of Shardings only on the backends that support memories right now.
PiperOrigin-RevId: 553942534
2023-08-04 16:27:15 -07:00
Yash Katariya
4fb8cdb019 [Memories] Add Memories support to jax.jit and jax.device_put!
These are the following changes:

* Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) (should not be used by user but needed in C++ in pjrt-ifrt code) on whether to fetch memory kinds from executable. If it is set to True, the host runtime dep needs to be linked in and should also work in OSS (more work needs to happen for that). So only the test sets it to True for now until jax memories is under development.

* Add with_memory_kind method on Sharding to allow for easier creation of shardings with different memory kind.

* Add lowering rules for device_put and jax.jit.
  * For device_put, we always add the annotation that describes a transfer to a memory and a sharding annotation.
  * For jax.jit, if the argument is on host memory, it will have an extra attribute _xla_buffer_placement.

* Handle the correct output sharding in pxla.py by extracting the memory kind from the executable.

* Handle the caching of pjit caches by canonicalizing the memory_kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize memory_kind in `__hash__` and `__eq__` of shardings.
  * This is to not change the StableHLO to include device placement annotations right now since the host aware passes are not enabled by default and the work is under progress to make it work everywhere.

PiperOrigin-RevId: 553833344
2023-08-04 09:44:24 -07:00
Peter Hawkins
0116d196a7 Prune some exports from jax.experimental.pjit.
jax.experimental.pjit is deprecated in its entirety (use "jit" instead), and experimental APIs have no stability promises.

PiperOrigin-RevId: 552903601
2023-08-01 13:27:17 -07:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
Yash Katariya
3929a63a74 Fix the bug where args_info didn't have correct donated bit for args when donate_argnums was set on jax.jit. Fixes https://github.com/google/jax/issues/16906
PiperOrigin-RevId: 552509081
2023-07-31 09:43:50 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
e2a49ee297 Tweaks the utility function _get_ppspec_from_executable to get the shardings directly from the executable (instead of from its HLO modules).
PiperOrigin-RevId: 549473458
2023-07-19 17:38:59 -07:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
jax authors
ed302cbdda Merge pull request #16685 from axch:ragged-jit
PiperOrigin-RevId: 547833923
2023-07-13 10:03:06 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
Matthew Johnson
e04db23651 Indirectify ragged axes across jitting boundaries, input- and output-side.
Also propagate DShapedArray through at least the simple cases of
shardings that show up in test cases.

Co-authored-by: Alexey Radul <axch@google.com>
2023-07-11 15:21:55 -04:00
Juliana Franco
f81a48a819 Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
2023-07-11 10:26:07 -07:00
Yash Katariya
744a64fce6 Make sharding on ShapeDtypeStruct a property that always exists. The previous behavior was it only existed if sharding was not None.
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
2023-06-26 21:46:50 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
yashkatariya
a65f74b392 Fix the docs build 2023-06-16 13:14:38 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follow:

* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
Yash Katariya
79a1bc9a3e No need to return jaxpr from common_infer_params since it is already in params
PiperOrigin-RevId: 539715375
2023-06-12 11:38:06 -07:00
Yash Katariya
fa099fd262 Simplify sharding types input to physical_hlo_sharding and lower_jaxpr_to_fun.
Make sure lower_jaxpr_to_fun always sees HloSharding in arg_shardings and results_shardings.

Also make sure physical_hlo_sharding only accepts HloSharding as the input.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 538342152
2023-06-06 18:03:12 -07:00
Yash Katariya
c7b372a8df Convert OpSharding to HloSharding in the constructor of GSPMDSharding. Also make op_sharding_to_indices to work with only HloSharding.
PiperOrigin-RevId: 538165282
2023-06-06 06:35:29 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
5cf0e042eb Change the _most_recent_executable logic to store a weakref dict of jaxpr -> executable so that with the inner cpp cache and outer cpp cache, we extract the correct executable.
PiperOrigin-RevId: 537908874
2023-06-05 10:07:05 -07:00
André Susano Pinto
cfabad5886 Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.
_get_arg_names was throwing IndexError when handling functions with variadic args.

PiperOrigin-RevId: 537308439
2023-06-02 07:43:59 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Yash Katariya
4a5c6f8200 For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing and was decided that it is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
Yash Katariya
42a8982d23 Smuggle _experimental_lowering_platform via kwargs to make it hidden and extremely private temporary.
PiperOrigin-RevId: 532644979
2023-05-16 19:47:58 -07:00
Yash Katariya
8b9e6bcbd4 For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the tracing time.
PiperOrigin-RevId: 532155068
2023-05-15 10:32:24 -07:00
Yash Katariya
b196ad2e8c Remove the f-string evaluation during logging the elapsed time by passing in fun_name to log_elapsed_time
PiperOrigin-RevId: 532132574
2023-05-15 09:15:58 -07:00
Yash Katariya
559b837ba5 Add logging if we get a C++ cache miss
PiperOrigin-RevId: 531555996
2023-05-12 11:19:58 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00