238 Commits

Author SHA1 Message Date
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
jax authors
00489be23d Fix a bug where exceptions were thrown in debug message formatting, when sharding was set to None on arrays.
PiperOrigin-RevId: 621193460
2024-04-02 08:56:37 -07:00
Yash Katariya
6e0c95585a Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation.
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.

The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.

PiperOrigin-RevId: 619292757
2024-03-26 13:28:45 -07:00
jax authors
e3bbd670bc Avoid jax_explain_cache_misses unpacking error.
PiperOrigin-RevId: 618931412
2024-03-25 12:55:00 -07:00
Yash Katariya
25d01e983c [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd

PiperOrigin-RevId: 618878870
2024-03-25 10:08:43 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
jax authors
cd79e71d85 Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b
PiperOrigin-RevId: 618127324
2024-03-22 03:46:09 -07:00
Yash Katariya
0e092a7706 Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
PiperOrigin-RevId: 618050680
2024-03-21 21:02:40 -07:00
Yash Katariya
d57bb8c748 Raise a better error message when an invalid input is passed to jit call.
Before:

```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

After:

```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.

```

The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information.

PiperOrigin-RevId: 618014044
2024-03-21 17:46:32 -07:00
Peter Hawkins
5532e5505b [XLA:Python] Add a C++ implementation of flatten_one_level.
Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function.

This is mostly just doing an old TODO.

PiperOrigin-RevId: 617988496
2024-03-21 15:57:23 -07:00
Peter Hawkins
54d8bde057 Don't tree_flatten in_shardings and out_shardings each time a jit() is traced.
Do it once when the jit is constructed.

(In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.)

PiperOrigin-RevId: 617859205
2024-03-21 09:00:16 -07:00
Yash Katariya
dd574cbc74 Remove _python_pjit and make _cpp_pjit the only function wrapper.
PiperOrigin-RevId: 617846352
2024-03-21 08:18:42 -07:00
Peter Hawkins
79b18948c3 Only call inspect.signature once during the initial call to jit().
We call inspect.signature() once for debug information and once for argnum resolving. We can just call it once and reuse the result.

PiperOrigin-RevId: 617824439
2024-03-21 06:36:07 -07:00
Peter Hawkins
d3e03fff5d Refactorings to the jit implementation.
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
2024-03-21 05:37:32 -07:00
Yash Katariya
cd1e55a351 Remove physical_hlo_sharding from TyRules.
The only caller of `physical_op_sharding` outside of TyRules was mlir.py. This CL also changes lower_jaxpr_to_fun to only accept logical arg_shardings and result_shardings which are XLACompatiableShardings.

PiperOrigin-RevId: 616267810
2024-03-15 16:02:13 -07:00
Matthew Johnson
8c2f6b3e8c re-enable pjit forwarding optimization, add tests 2024-03-15 14:06:35 -07:00
Matthew Johnson
8a7c604aa7 disable optimization 2024-03-15 10:35:08 -07:00
Matthew Johnson
c515f15e01 fix residual forwarding bug, fixes #20267 2024-03-15 10:07:29 -07:00
Matthew Johnson
649cd50681 [mutable-arrays] support closed-over mutable arrays in jit 2024-03-13 09:59:03 -07:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Yash Katariya
ca3f3f0f17 Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal.
PiperOrigin-RevId: 613009976
2024-03-05 16:36:49 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Yash Katariya
217f08236e Allow sharding propagation to input for prng keys whose sharding is not specified.
Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings which allows to remove `are_out_shardings_from_xla`.

PiperOrigin-RevId: 611246986
2024-02-28 15:22:16 -08:00
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Peter Hawkins
f1ea67117e Split name_stack out of mlir.ModuleContext.
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else (naturally) is the Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
2024-02-20 07:17:23 -08:00
Peter Hawkins
b5e4ba4900 Don't call inspect.signature() each time we trace a jit().
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
2024-02-15 13:49:27 -08:00
Peter Hawkins
885e8a2311 Don't recompute abstract eval rules when inlining a jit jaxpr.
The current implementation of jit inlining uses core.eval_jaxpr() and retraces the subjaxpr. This ends up performing abstract evaluation a second time. Instead, write a direct implementation of inlining that doesn't use the tracing machinery.

PiperOrigin-RevId: 607418006
2024-02-15 12:28:48 -08:00
Matthew Johnson
a45cc437f4 [attrs] allow passing a jax-attrs object to jit functions
currently we don't get any interesting cache hits; only on object identity
match
2024-02-13 16:53:46 -08:00
jax authors
80d23d64cd Merge pull request #19566 from mattjj:attrs-aqt
PiperOrigin-RevId: 602864008
2024-01-30 15:51:00 -08:00
Matthew Johnson
6c2d9c7e3a add getstate/setstate in pjit transpose, for bwd pass effects
Co-authored-by: Roy Frostig <frostig@google.com>
2024-01-29 20:03:11 -08:00
Yash Katariya
d9122b8bac Add sharding to ShapeDtypeStruct retured by eval_shape if jit has out_shardings specified
PiperOrigin-RevId: 602556016
2024-01-29 18:02:51 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Yash Katariya
a63197fed8 Add an internal _device_list parameter to GSPMDSharding so that we can save on the initialization cost of PyDeviceList when creating GSPMDSharding from other shardings
PiperOrigin-RevId: 601055733
2024-01-24 02:29:22 -08:00
Sergei Lebedev
46f796b38d Dedupe shardings before passing them to _get_and_check_device_assignment
In practice, the number of different shardings is usually much smaller then
the number of inputs/output.

PiperOrigin-RevId: 600558309
2024-01-22 13:45:20 -08:00
Yash Katariya
3a0b495faa Internal change
PiperOrigin-RevId: 600007054
2024-01-19 20:35:07 -08:00
Yash Katariya
f04f305489 Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Yash Katariya
51ef738c86 Use jit's jaxpr creation function for eval_shape to maximize tracing cache hits.
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.

The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`

PiperOrigin-RevId: 599602407
2024-01-18 13:11:44 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Matthew Johnson
9112dcebc9 add jax.explain_cache_misses tracing cache miss explanations
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information

The main idea here is that pointing out which parts of the cache key differs
from previously-seen keys can constitute a pretty good explanation.

This PR adds an explanation mechanism. It can be enabled in a few different ways:
  * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
  * setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
  * using the context manager `jax._src.config.explain_cache_misses` context
    manager (not in public namespace yet);
  * when parsing command line flags with absl, using the
    `--jax_explain_cache_misses` flag.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-26 21:54:27 -08:00
Yash Katariya
9792e00887 Cleanup _find_arg_mismatch logic
PiperOrigin-RevId: 592697969
2023-12-20 17:24:26 -08:00
Yash Katariya
90e47fbc6d Always flatten args and kwargs together i.e. tree_flatten((args, kwargs)) so that we have a uniform in_tree structure everywhere.
Leads to a code cleanup and more standardization in jit.

PiperOrigin-RevId: 592388438
2023-12-19 17:32:07 -08:00
Yash Katariya
9b6bf2cab0 Call shard_arg fallback in pjit's cpp fast path instead of dropping out completely.
PiperOrigin-RevId: 592344105
2023-12-19 14:26:01 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
jax authors
616f4d29bb Merge pull request #18888 from superbobry:pp-improvement
PiperOrigin-RevId: 590269555
2023-12-12 11:12:42 -08:00
Sergei Lebedev
840abfb7ab The pretty printer now de-duplicates identical jaxprs
This compresses the output e.g. when a jitted function is called repeatedly
in a Python loop.
2023-12-12 17:14:43 +00:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Yash Katariya
5fb8ceca73 Make lowering oblivious to real physical devices. Instead cache lowering on HloSharding only (which is based on logical device numbers)
Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.

PiperOrigin-RevId: 589244695
2023-12-08 14:36:09 -08:00