When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
```
name old cpu/op new cpu/op delta
jit_add_chain 59.1ms ±14% 49.4ms ±10% -16.32% (p=0.008 n=5+5)
name old time/op new time/op delta
jit_add_chain 60.3ms ±14% 50.7ms ±11% -15.99% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
The motivation for doing this is 2-fold:
1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.
2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.
Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.
Designed with @froystig!
PiperOrigin-RevId: 644087787
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.
The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
PiperOrigin-RevId: 644051624
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.
PiperOrigin-RevId: 643097852
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.
Since we can keep the existing behavior and still merge the implementation is a good cleanup!
Fixes https://github.com/google/jax/issues/21116
PiperOrigin-RevId: 641347140
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation of a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
with a `Lowered` object and will not include anymore its own lowering code.
We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).
The function name is useful for better error messages when exporting and
serializating. The Jaxpr is useful for exporting also the VJP of the function
and obtaining an `Exported` that can be differentiated.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.
No functional changes intended, this is just to improve readability.
PiperOrigin-RevId: 617812557
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.
PiperOrigin-RevId: 612416803
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.
The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`
PiperOrigin-RevId: 599602407
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).
If your code ends up taking the python dispatch, then something is going wrong anyways.
PiperOrigin-RevId: 596081987
Docstring states:
> If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.
However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`.
PiperOrigin-RevId: 595731210