The currently supported values for compute type are `device_host` and `device`; `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host.
`cpu` as a compute_type is reserved for pure CPU-only computations that are not orchestrated by a device's PJRT client.
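A minimal sketch of how a host computation might be expressed; the `compute_on` import path and decorator form are assumptions about the experimental API, not verified here:
```
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on  # assumed import path

# Run this sub-computation on the host; the device's PJRT client orchestrates it.
@compute_on('device_host')
@jax.jit
def host_double(x):
  return x * 2

@jax.jit
def f(x):
  y = host_double(x)  # offloaded to the host
  return y + 1        # runs on the device
```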
PiperOrigin-RevId: 634909918
Fix this by calculating the donation vector from the in_tree.
A bonus is that we can now cache the calculation of the donation vector, leading to faster tracing times in JAX.
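For context, donation is declared per top-level argument and flattened against the in_tree into a per-leaf donation vector; an illustrative example using the standard API:
```
from functools import partial
import jax
import jax.numpy as jnp

# donate_argnums marks arguments whose buffers may be reused for the outputs;
# the per-leaf donation vector comes from flattening this against the in_tree.
@partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {'w': jnp.ones((4,)), 'b': jnp.zeros((4,))}
grads = {'w': jnp.full((4,), 0.5), 'b': jnp.full((4,), 0.1)}
params = update(params, grads)
```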
PiperOrigin-RevId: 627512710
- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.
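For context, a minimal shard_map usage sketch (illustrative only; the fixes above concern how the manual mesh axes are propagated internally):
```
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))

@jax.jit
def f(a):
  # Inside shard_map the mesh axis 'x' is "manual": the body sees per-device
  # shards of `a` rather than the global array.
  return shard_map(lambda s: s * 2, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(a)

out = f(jnp.arange(8.))
```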
PiperOrigin-RevId: 627156892
This is because the tracing, lowering, and compilation caches do not register a miss if the sharding/layout of a DCE'd arg changes when it is passed again to a jitted function.
This is not true for avals, so that check still exists.
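For reference, a minimal illustration of a jitted function with an argument that gets DCE'd (names are illustrative):
```
import jax
import jax.numpy as jnp

@jax.jit
def f(x, unused):
  # `unused` does not contribute to the output, so it is dead-code-eliminated
  # from the lowered computation.
  return x * 2

out = f(jnp.ones(4), jnp.zeros((2, 2)))
```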
PiperOrigin-RevId: 623375760
Currently, we only support this case:
* If kwargs are specified, then all in_shardings must be specified as a dict matching the kwargs. Mixing args and kwargs is not allowed: either everything is passed as kwargs (and in_shardings is a dict) or everything is passed positionally (and in_shardings is specified positionally).
Example:
```
from functools import partial
import jax

@partial(jax.jit, in_shardings=dict(y=s2, x=s1))
def f(x, y):
  return x * 2, y * 2

f(x=arr, y=arr2)
```
Fixes https://github.com/google/jax/issues/17400
PiperOrigin-RevId: 623018032
`jax.jit` now accepts `Layout` instances for the `in_shardings` and `out_shardings` arguments. The major changes are just plumbing `in_layouts` and `out_layouts` everywhere.
Note that the public API is `Layout(device_local_layout, sharding)`, which is how users will pass us the Layout, but internally we split it apart into device_local_layout and sharding.
Docs on how to use the API, what Layouts mean, and how to make sense of them (especially on TPU) are coming up.
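A rough sketch of the API shape; only the `Layout(device_local_layout, sharding)` constructor form is stated above, so the module path and the `DeviceLocalLayout.AUTO` spelling below are assumptions:
```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout  # assumed import path

mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))

def f(x):
  return x * 2

# Pin the sharding and let the compiler choose the per-device layout (AUTO).
jitted = jax.jit(
    f,
    in_shardings=Layout(DeviceLocalLayout.AUTO, sharding),
    out_shardings=Layout(DeviceLocalLayout.AUTO, sharding),
)
```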
PiperOrigin-RevId: 622352537
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.
The canonicalization can be done by lowering to HloSharding where it is required, and there are utilities to help with that.
PiperOrigin-RevId: 619292757
Before:
```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
After:
```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information.
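A minimal sketch of the kind of call that triggers this error, with the argument structure chosen to match the path in the message above:
```
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return x['b']['c'] * 2

# Passing an abstract ShapeDtypeStruct where a concrete array is expected fails
# during shard_arg; the message now names the offending path, x['b']['c'].
f({'b': {'c': jax.ShapeDtypeStruct((4, 2), jnp.int32)}})
```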
PiperOrigin-RevId: 618014044
Also add a copy of the default registry in which None is not registered as a pytree node, so it is treated as a leaf; this is slightly faster than passing an is_leaf function.
This is mostly just doing an old TODO.
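For reference, the is_leaf alternative that the dedicated registry avoids looks like this (standard tree_util API):
```
import jax

# With the default registry, None flattens to no leaves; treating it as a leaf
# requires an is_leaf callback on every flatten call.
leaves, treedef = jax.tree_util.tree_flatten(
    {'x': None, 'y': 1}, is_leaf=lambda v: v is None)
# leaves == [None, 1]
```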
PiperOrigin-RevId: 617988496
Do it once when the jit is constructed.
(In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.)
PiperOrigin-RevId: 617859205
We call inspect.signature() once for debug information and once for argnum resolution. We can just call it once and reuse the result.
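A toy illustration of the pattern (not the actual jax internals):
```
import inspect

def g(x, y, *, z=0):
  return x + y + z

# Resolve the signature once, then reuse it for both purposes: argument names
# for debug info, and parameter kinds/positions for argnum resolution.
sig = inspect.signature(g)
arg_names = tuple(sig.parameters)  # ('x', 'y', 'z')
positional = [name for name, p in sig.parameters.items()
              if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
```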
PiperOrigin-RevId: 617824439
Notably:
* We can share more code between jit and pjit: there's no significant difference between the two other than the handling of the resource environment.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters, into separate arguments, but the result reads more cleanly to me (a rough sketch follows below).
No functional changes intended, this is just to improve readability.
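A rough sketch of the shape of such a container (the field names are an illustrative subset, not the actual jax internals):
```
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class PjitInfoSketch:
  """Only what is known at jit() time, so it can be built once up front."""
  fun: Callable
  in_shardings: Any
  out_shardings: Any
  static_argnums: tuple
  donate_argnums: tuple
  keep_unused: bool
  inline: bool
  # The resource environment and the two layout parameters are deliberately
  # excluded: per the note above, they cannot be deduced when jit() is called,
  # so they are passed as separate arguments to the params-inference step.
```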
PiperOrigin-RevId: 617812557
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings, which are `XLACompatibleSharding`s.
PiperOrigin-RevId: 616267810
Also comment out the key-reuse check in the C++ dispatch path, since it is True under JAX tests, which prevents PRNG keys from taking C++ dispatch.
PiperOrigin-RevId: 613289252
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
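A rough sketch of what steps 1-4 enable; the constructor spelling below is the tests-only internal one and is an assumption:
```
import jax
import jax.numpy as jnp
from jax._src import core  # tests-only for now, per step 1 above

x_ref = core.mutable_array(jnp.zeros(3))  # a MutableArray, abstracted to Ref(ShapedArray)

@jax.jit
def add_one(ref):
  ref[...] = ref[...] + 1.  # a state effect on the Ref, discharged during lowering

add_one(x_ref)  # the discharged extra output is aliased back into x_ref's buffer
```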