Users should be able to load checkpoints with the layout that `train_step` specifies, via `device_put`.
Note: This currently only works on TPU.
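A minimal sketch of the intended flow, assuming an AOT-compiled `train_step`; the `input_layouts` accessor below is a placeholder for however the expected input layout is obtained, not a confirmed API:

```python
import jax
import jax.numpy as jnp

def train_step(state):
    return state * 2.0

state = jnp.zeros((8, 128))

# Compile train_step ahead of time, then device_put the restored
# checkpoint value straight into the layout the executable expects.
# NOTE: `input_layouts` is a hypothetical accessor used for illustration;
# `state` stands in for data restored from a checkpoint.
compiled = jax.jit(train_step).lower(state).compile()
restored = jax.device_put(state, compiled.input_layouts[0])
```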
PiperOrigin-RevId: 621668247
Notably:
* We can share more code between jit and pjit. There is no significant difference between the two other than the handling of the resource environment.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand both cases, so there's no reason to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so that it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require splitting out the couple of things we cannot deduce at that time, namely the resource environment and the two layout parameters, into separate arguments, but the result reads more cleanly to me (sketched below).
No functional changes intended, this is just to improve readability.
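Schematically, the result looks something like this; the field and parameter names are illustrative, not the actual definitions:

```python
from dataclasses import dataclass
from typing import Any, Callable

# Illustrative only: PjitInfo holds just the things known at jit() time,
# so it can be constructed once, ahead of time.
@dataclass(frozen=True)
class PjitInfo:
    fun: Callable
    static_argnums: tuple
    donate_argnums: tuple
    # ... everything else deducible from jit()'s arguments

# What cannot be deduced at jit() time is passed separately.
def _infer_params(jit_info: PjitInfo, resource_env: Any,
                  in_layouts: Any, out_layouts: Any, args, kwargs):
    ...
```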
PiperOrigin-RevId: 617812557
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.
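Schematically, the bottom-up direction looks like this (a toy sketch with a hypothetical per-equation query, not the actual implementation):

```python
# Toy sketch: aggregate axis names bottom-up from each equation,
# rather than substituting them top-down via core.subst_axis_names.
def axis_names_of_eqn(eqn):
    # hypothetical: names mentioned by this equation's parameters
    return set(eqn.params.get("axis_names", ()))

def jaxpr_axis_names(jaxpr):
    names = set()
    for eqn in jaxpr.eqns:
        names |= axis_names_of_eqn(eqn)
    return names
```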
PiperOrigin-RevId: 612416803
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e., in pmap terminology, using a "chunked" axis instead of an "unstacked" axis.
For example: previously, a typical array used by pmap might have a shape of [8, 100] if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of [100]. With this change, each individual shard has a shape of [1, 100] instead.
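A small example of the observable difference (real API; the chunked shape in the comment is what this change produces once enabled):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()        # e.g. 8 accelerators
x = jnp.zeros((n, 100))             # pmap shards this on its first axis
y = jax.pmap(lambda a: a * 2)(x)
for shard in y.addressable_shards:
    # (100,) with today's "unstacked" axis;
    # (1, 100) once the "chunked" representation is enabled.
    print(shard.data.shape)
```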
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX; for example, in a subsequent change we can probably delete PmapSharding entirely. That in turn would also allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards` or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user-visible impact yet.
PiperOrigin-RevId: 599787818
This comes up in LLMs, where we trace twice (once for eval_shape (usually the init function) and again during jit) even though the resulting jaxpr is the same. This shouldn't happen, and we should cache as much as possible.
The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe that's OK if we want to deprecate eval_shape in favor of an AOT-style method on `jax.jit`, or have it be a thin wrapper around something like `jax.jit(f).eval_shape`.
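For example, this is the pattern the change targets; with the cache, the second call should reuse the trace from the first rather than retracing `init`:

```python
import jax
import jax.numpy as jnp

def init(x):
    return x @ x.T

x = jnp.zeros((4, 8))
shapes = jax.eval_shape(init, x)   # traces init once
out = jax.jit(init)(x)             # should hit the tracing cache, not retrace
```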
PiperOrigin-RevId: 599602407
This only affects the Python dispatch path. It has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).
If your code ends up taking the Python dispatch path, something is going wrong anyway.
PiperOrigin-RevId: 596081987
The docstring states:
> If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.
However, `static_argnums` is not an argument that exists; I believe this should be corrected to `static_broadcasted_argnums`.
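For reference, the parameter the docstring means:

```python
import jax
import jax.numpy as jnp

# static_broadcasted_argnums marks positional arguments that are treated
# as compile-time constants and broadcast to every device.
f = jax.pmap(lambda x, n: x * n, static_broadcasted_argnums=1)
y = f(jnp.ones((jax.local_device_count(), 4)), 3.0)
```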
PiperOrigin-RevId: 595731210
We expose 3 modes:
* `SpecifiedLayout`: The user specifies the `minor_to_major` field of the layout (illustrated below). Tiling is not exposed yet.
* `DefaultLayout`: PJRT chooses the layout; this matches the current behavior.
* `AUTO`: The compiler chooses the layout. This is not a layout per se; it's a request to get the layout from the compiler. It cannot appear on an Array or other data types and can only be passed to jit.
Public API coming soon.
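To illustrate what `minor_to_major` encodes, a NumPy-only sketch (since the JAX public API is not exposed yet):

```python
import numpy as np

# minor_to_major=(1, 0) is row-major: the last axis varies fastest.
# minor_to_major=(0, 1) is column-major, like np.asfortranarray.
a = np.zeros((4, 8))           # row-major by default
b = np.asfortranarray(a)       # same values, column-major layout
print(a.strides, b.strides)    # (64, 8) vs. (8, 32)
```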
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when querying the readiness of an array is not cheap. Also, calling `x.block_until_ready()` repeatedly acquires and releases the GIL.
To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.
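Usage is unchanged; the win is that a single call now waits on all arrays together:

```python
import jax
import jax.numpy as jnp

xs = [jnp.ones((512, 512)) @ jnp.ones((512, 512)) for _ in range(8)]
# One call gathers the ready futures of every array in the pytree and
# waits once, instead of blocking on each array in turn.
xs = jax.block_until_ready(xs)
```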
PiperOrigin-RevId: 581302290
Previously, we introduced support for multi-platform lowering by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that lowering rules should not
simply assume single-platform lowering.
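For illustration, a lowering rule under the new field might look like this (schematic; the rule body is hypothetical):

```python
# Illustrative shape of a platform-aware lowering rule after this change.
def _my_prim_lowering(ctx, x):
    platforms = ctx.module_context.platforms   # e.g. ("cpu", "cuda")
    if len(platforms) > 1:
        # Must emit code that is valid for every requested platform.
        ...
    ...
```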
PiperOrigin-RevId: 576575376