464 Commits

Author SHA1 Message Date
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
Yash Katariya
671fb1265d Update the multi-process note in pjit's docstring
PiperOrigin-RevId: 632160561
2024-05-09 08:38:29 -07:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Yash Katariya
755f350910 Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove global_in_avals argument from lower_sharding_computation
PiperOrigin-RevId: 628564679
2024-04-26 18:29:18 -07:00
Yash Katariya
8239674dab Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
Yash Katariya
3f17626f4b Fix donation with kwargs. The problem is that pytrees sort dictionaries by default. So if we create the donation vector with original kwargs order, it won't match the aval order (which is created by sorting kwargs i.e. dict) and we end up donating the wrong input.
Fix this by calculating the donation vector by looking at the in_tree.

A bonus is that we can now cache the calculation of donation vector leading to faster tracing times in JAX.

PiperOrigin-RevId: 627512710
2024-04-23 14:50:04 -07:00
Parker Schuh
7ba811eb4a Support auto in shard_map.
- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.

PiperOrigin-RevId: 627156892
2024-04-22 14:29:35 -07:00
Yash Katariya
eb92a5c711 Add layout support to make_array_from_callback.
PiperOrigin-RevId: 625048520
2024-04-15 12:38:34 -07:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
Matthew Johnson
c33126f45a fix 2024-04-12 14:25:38 -07:00
Dougal Maclaurin
f313a46916
Merge branch 'main' into refs-in-vjps 2024-04-12 15:25:37 -04:00
Dougal
29368e6a8e Add a zeros rule for mutable arrays and test it using a custom vjp.
add jit compatibility (have pjit jvp instantiate all ref tangents)

Co-authored-by: Matt Johnson <mattjj@google.com>
2024-04-12 15:22:07 -04:00
Jake VanderPlas
d5405bd92f [key reuse] handle reuse of closed-over constants 2024-04-11 15:39:45 -07:00
Yash Katariya
0d8eb45c20 Remove the sharding and layout checks for non-DCE'd arguments during AOT safe call.
This is because the tracing, lowering and compilation caches do not register a miss if sharding/layout of a DCE'd arg changes when it's passed again to a jitted function.

This is not true for avals so that check still exists.

PiperOrigin-RevId: 623375760
2024-04-09 22:12:05 -07:00
Yash Katariya
d1b1d0b019 Reverts a1c8207caea8bbc323bbcfb7735768822a59f5ce
PiperOrigin-RevId: 623045488
2024-04-08 21:35:02 -07:00
Yash Katariya
a1c8207cae Add kwargs support to in_shardings argument of jax.jit.
Currently, we only support this case:

* If kwargs are specified, then all in_shardings should be specified as dict matching the kwargs. args and kwargs mixture is not allowed. Either everything are kwargs or args hence in_shardings is a dict or specified positionally.

Example:

```
@partial(jax.jit, in_shardings=dict(y=s2, x=s1))
def f(x, y):
  return x * 2, y * 2

f(x=arr, y=arr2)
```

Fixes https://github.com/google/jax/issues/17400

PiperOrigin-RevId: 623018032
2024-04-08 19:19:56 -07:00
Yash Katariya
c3f5af7d46 Delete deprecated AOT layouts API.
PiperOrigin-RevId: 622666838
2024-04-07 14:15:36 -07:00
Yash Katariya
3b5980fd73 Share lowering code between jit and aot jit path
PiperOrigin-RevId: 622487044
2024-04-06 13:44:18 -07:00
Yash Katariya
c125442644 Add Layout support to jax.jit.
`jax.jit` now accepts `Layout` instances to the `in_shardings` and `out_shardings` argument. Major changes are just plumbing `in_layouts` and `out_layouts` everywhere.

Note that public api is `Layout(device_local_layout, sharding)` which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
2024-04-05 20:09:34 -07:00
Yash Katariya
55233a0029 device_local_layout can be None on a jax.Array for backends that don't implement certain required methods for a jax.Array to populate the device_local_layout.
Skip the error checks when arr.layout.device_local_layout is None.

PiperOrigin-RevId: 622007598
2024-04-04 16:42:27 -07:00
Yash Katariya
52f7de0969 Remove the unused return from prepare_axis_resources
PiperOrigin-RevId: 621738698
2024-04-03 22:39:42 -07:00
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
jax authors
00489be23d Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
2024-04-02 08:56:37 -07:00
Yash Katariya
6e0c95585a Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation.
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.

The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.

PiperOrigin-RevId: 619292757
2024-03-26 13:28:45 -07:00
jax authors
e3bbd670bc Avoid jax_explain_cache_misses unpacking error.
PiperOrigin-RevId: 618931412
2024-03-25 12:55:00 -07:00
Yash Katariya
25d01e983c [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd

PiperOrigin-RevId: 618878870
2024-03-25 10:08:43 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
jax authors
cd79e71d85 Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b
PiperOrigin-RevId: 618127324
2024-03-22 03:46:09 -07:00
Yash Katariya
0e092a7706 Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
PiperOrigin-RevId: 618050680
2024-03-21 21:02:40 -07:00
Yash Katariya
d57bb8c748 Raise a better error message when an invalid input is passed to jit call.
Before:

```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

After:

```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information.

PiperOrigin-RevId: 618014044
2024-03-21 17:46:32 -07:00
Peter Hawkins
5532e5505b [XLA:Python] Add a C++ implementation of flatten_one_level.
Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function.

This is mostly just doing an old TODO.

PiperOrigin-RevId: 617988496
2024-03-21 15:57:23 -07:00
Peter Hawkins
54d8bde057 Don't tree_flatten in_shardings and out_shardings each time a jit() is traced.
Do it once when the jit is constructed.

(In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.)

PiperOrigin-RevId: 617859205
2024-03-21 09:00:16 -07:00
Yash Katariya
dd574cbc74 Remove _python_pjit and make _cpp_pjit the only function wrapper.
PiperOrigin-RevId: 617846352
2024-03-21 08:18:42 -07:00
Peter Hawkins
79b18948c3 Only call inspect.signature once during the initial call to jit().
We call inspect.signature() once for debug information and once for argnum resolving. We can just call it once and reuse the result.

PiperOrigin-RevId: 617824439
2024-03-21 06:36:07 -07:00
Peter Hawkins
d3e03fff5d Refactorings to the jit implementation.
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
2024-03-21 05:37:32 -07:00
Yash Katariya
cd1e55a351 Remove physical_hlo_sharding from TyRules.
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.

PiperOrigin-RevId: 616267810
2024-03-15 16:02:13 -07:00
Matthew Johnson
8c2f6b3e8c re-enable pjit forwarding optimization, add tests 2024-03-15 14:06:35 -07:00
Matthew Johnson
8a7c604aa7 disable optimization 2024-03-15 10:35:08 -07:00
Matthew Johnson
c515f15e01 fix residual forwarding bug, fixes #20267 2024-03-15 10:07:29 -07:00
Matthew Johnson
649cd50681 [mutable-arrays] support closed-over mutable arrays in jit 2024-03-13 09:59:03 -07:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Yash Katariya
ca3f3f0f17 Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal.
PiperOrigin-RevId: 613009976
2024-03-05 16:36:49 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00