Users should be able to load checkpoints with the layout that `train_step` specifies, via `device_put`.
Note: This currently only works on TPU.
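A minimal sketch of the intended flow, assuming this era's experimental layout API (`Compiled.input_layouts` and the structure of its return value are assumptions; `jax.device_put` accepting a layout is the feature described above):
```
import jax
import jax.numpy as jnp

def train_step(params):
    return params * 2.0

params = jnp.zeros((8, 128))
compiled = jax.jit(train_step).lower(params).compile()

# Pull out the compiler-chosen layout for the first argument; the exact
# pytree structure of `input_layouts` varies across versions.
layout = jax.tree_util.tree_leaves(compiled.input_layouts)[0]

# Checkpoint data can now be placed directly into that layout:
restored = jax.device_put(params, layout)
out = compiled(restored)
```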
PiperOrigin-RevId: 621668247
The canonicalization no longer provides any value and only makes the internals more complicated.
Where canonicalization is still required, it can be done by lowering to HloSharding, and there are utilities to help with that.
PiperOrigin-RevId: 619292757
Before:
```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
After:
```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information.
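A minimal sketch of that pattern (the helper bodies here are hypothetical, not JAX's actual code): raise a dedicated exception deep in the stack where only the offending value is known, and catch it where the argument names are available.
```
class InvalidInputException(Exception):
    def __init__(self, value):
        super().__init__(value)
        self.value = value

def shard_arg(value):
    # Deep in the stack: we only have the bad value, not its argument name.
    raise InvalidInputException(value)

def python_pjit_helper(arg_names, args):
    # Here we do have `arg_names`, so we can re-raise with a useful message.
    try:
        for arg in args:
            shard_arg(arg)
    except InvalidInputException as e:
        name = arg_names[args.index(e.value)]
        raise TypeError(
            f"Argument '{name}' of type {type(e.value)} is not a valid JAX type."
        ) from None
```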
PiperOrigin-RevId: 618014044
Also add a copy of the default registry that doesn't have `None` registered at all, so `None` is treated as a leaf; this is slightly faster than using an `is_leaf` function.
This is mostly just doing an old TODO.
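For context, the per-node callback spelling this replaces (real `jax.tree_util` API):
```
import jax

tree = {"a": None, "b": [1, 2]}

# Treating None as a leaf via a Python callback invoked at every node:
leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)
# -> [None, 1, 2]; a registry in which None is unregistered gives the same
#    result without the per-node Python call.
```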
PiperOrigin-RevId: 617988496
Do it once when the jit is constructed.
(In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.)
PiperOrigin-RevId: 617859205
We call `inspect.signature()` once for debug information and once for argnum resolution. We can just call it once and reuse the result.
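A sketch of the change (the helper names here are hypothetical): resolve the signature once and thread the result through.
```
import inspect

def _debug_arg_names(sig):
    # e.g. the argument names used in error messages
    return tuple(sig.parameters)

def _resolve_argnames(sig, static_argnames):
    names = list(sig.parameters)
    return tuple(names.index(n) for n in static_argnames)

def _setup(fun, static_argnames):
    sig = inspect.signature(fun)  # called once, reused below
    return _debug_arg_names(sig), _resolve_argnames(sig, static_argnames)
```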
PiperOrigin-RevId: 617824439
Notably:
* We can share more code between jit/pjit: other than the handling of the resource environment, there's no significant difference between the two.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand both cases, so there's no reason to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require splitting out a couple of things we cannot deduce at that time, namely the resource environment and the two layout parameters, into separate arguments, but the result reads more cleanly to me.
No functional changes intended; this is just to improve readability.
PiperOrigin-RevId: 617812557
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to accept only logical arg_shardings and result_shardings, which are XLACompatibleShardings.
PiperOrigin-RevId: 616267810
Also comment out the key-reuse check in C++ dispatch, since it's True for jax tests, which prevents prng keys from taking the C++ dispatch path.
PiperOrigin-RevId: 613289252
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
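A conceptual sketch of the before/after (toy code, not the actual MutableArray or jax._src.state machinery; `donate_argnums` stands in here for the aliasing that lower_sharding_computation sets up):
```
from functools import partial
import jax
import jax.numpy as jnp

# Effectful form (what the user writes, conceptually):
#   def step(ref, x): ref[...] = ref[...] + x
# Discharged form (what lowering produces): the new ref value becomes an
# extra output, and donating the input expresses the input/output aliasing.
@partial(jax.jit, donate_argnums=(0,))
def step_discharged(state, x):
    return state + x

state = jnp.zeros(3)
state = step_discharged(state, jnp.ones(3))
```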
Convert shardings returned by XLA (when propagation is on for inputs and outputs) for extended dtypes to user shardings, which allows us to remove `are_out_shardings_from_xla`.
PiperOrigin-RevId: 611246986
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though almost everything else (naturally) has a lifetime of Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.
PiperOrigin-RevId: 608594374
The current implementation of jit inlining uses core.eval_jaxpr() and retraces the subjaxpr, which ends up performing abstract evaluation a second time. Instead, write a direct implementation of inlining that doesn't use the tracing machinery.
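A toy illustration of the direct approach (plain dicts and strings standing in for jaxpr equations and variables, not JAX's actual data structures): splice the callee's equations into the caller with renamed variables rather than re-tracing.
```
def inline_call(outer_eqns, call_eqn, inner_jaxpr):
    # Map the callee's parameters/results onto the caller's variables.
    env = dict(zip(inner_jaxpr["invars"], call_eqn["invars"]))
    env.update(zip(inner_jaxpr["outvars"], call_eqn["outvars"]))
    # Intermediates get fresh names so they can't collide with the caller's.
    rename = lambda v: env.setdefault(v, v + "_inlined")
    new_eqns = []
    for eqn in outer_eqns:
        if eqn is call_eqn:
            for e in inner_jaxpr["eqns"]:
                new_eqns.append({
                    "prim": e["prim"],
                    "invars": [rename(v) for v in e["invars"]],
                    "outvars": [rename(v) for v in e["outvars"]],
                })
        else:
            new_eqns.append(eqn)
    return new_eqns
```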
PiperOrigin-RevId: 607418006
This comes up in LLM models, where we trace twice (once for eval_shape (usually the init function) and again during jit) even though the output jaxpr is the same. This shouldn't happen, and we should cache as much as possible.
The only caveat is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe that's ok if we want to deprecate eval_shape in favor of an AOT-style method on `jax.jit`, or have it be a thin wrapper around something like `jax.jit(f).eval_shape`.
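The double-trace pattern being cached, using the public API:
```
import jax
import jax.numpy as jnp

def init(x):
    return {"w": x @ x.T}

x = jnp.zeros((512, 128))
shapes = jax.eval_shape(init, x)  # first trace, shapes/dtypes only
params = jax.jit(init)(x)         # identical jaxpr; should reuse the trace
```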
PiperOrigin-RevId: 599602407
This only affects the Python dispatch path; it has no impact on the speed of C++ dispatch (which is why benchmarks are **not** regressing).
If your code ends up taking the Python dispatch path, something is going wrong anyway.
PiperOrigin-RevId: 596081987
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information.
The main idea here is that pointing out which parts of the cache key differ
from previously-seen keys can constitute a pretty good explanation.
This PR adds an explanation mechanism. It can be enabled in a few different ways:
* setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
* setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
* using the `jax._src.config.explain_cache_misses` context manager (not in the public namespace yet);
* when parsing command line flags with absl, using the
`--jax_explain_cache_misses` flag.
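For example (the config name comes from this PR; the exact explanation text printed is illustrative):
```
import jax
import jax.numpy as jnp

jax.config.update('jax_explain_cache_misses', True)

for step in range(3):
    f = lambda x: x + step     # a fresh callable object each iteration...
    jax.jit(f)(jnp.ones(3))    # ...so every call misses the cache and retraces
```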
Co-authored-by: Yash Katariya <yashkatariya@google.com>
This PR is a follow-up to #18881.
The changes were generated by adding
```
from __future__ import annotations
```
to the files which did not already have them, and running
```
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
```