464 Commits

Author SHA1 Message Date
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
73a8e7f0a8 Rename is_compatible_aval to check_compatible_aval since it returns None and not a bool.
PiperOrigin-RevId: 638431968
2024-05-29 15:29:32 -07:00
George Necula
bb4c073574 Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation of a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
with a `Lowered` object and will not include anymore its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializating. The Jaxpr is useful for exporting also the VJP of the function
and obtaining an `Exported` that can be differentiated.
2024-05-29 05:04:17 +03:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Lianmin Zheng
0eed28a010
Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Yash Katariya
8239674dab Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
jax authors
1f4c31d0af Merge pull request #20849 from mattjj:jit-docstring-tweaks
PiperOrigin-RevId: 627167535
2024-04-22 15:05:38 -07:00
Yash Katariya
1837b436d7 Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
2024-04-19 20:59:55 -07:00
Matthew Johnson
b8df23c25b tweak jit docstring 2024-04-19 17:37:52 -07:00
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
Matthew Johnson
83a200a42f simple fix to make_jaxpr docstring
maybe it was accidentally copied from xla_computation before?
2024-04-11 21:50:51 -07:00
Sai-Suraj-27
5564521308 Prefer raising of TypeError for invalid types instead of ValueError. 2024-04-08 13:08:24 +05:30
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Peter Hawkins
d3e03fff5d Refactorings to the jit implementation.
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
2024-03-21 05:37:32 -07:00
Junwhan Ahn
f569031456 Reverts 55394a0914dc0583427a4ceb73dac56348911d15
PiperOrigin-RevId: 616201321
2024-03-15 11:56:45 -07:00
Peter Hawkins
642f20de1c [JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.
Bytecode is considerably more compact.

PiperOrigin-RevId: 615386276
2024-03-13 06:02:18 -07:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Matthew Johnson
b0b88d87d3 [attrs] add linearize and vjp support 2024-02-23 16:43:49 -08:00
Peter Hawkins
b5e4ba4900 Don't call inspect.signature() each time we trace a jit().
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
2024-02-15 13:49:27 -08:00
Jake VanderPlas
82611eb8ae document that under disable_jit, individual primitives are still compiled 2024-02-05 12:01:33 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Yash Katariya
f04f305489 Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Matthew Johnson
0d4f200e08 Allow unhashable callables in jax.eval_shape.
PiperOrigin-RevId: 599691923
2024-01-18 19:16:48 -08:00
Yash Katariya
51ef738c86 Use jit's jaxpr creation function for eval_shape to maximize tracing cache hits.
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.

The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`

PiperOrigin-RevId: 599602407
2024-01-18 13:11:44 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Tom Cobley
ebc7af95df Fix typo in pmap docstring
Docstring states:
>  If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.

However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`.

PiperOrigin-RevId: 595731210
2024-01-04 09:50:00 -08:00
Axel Donath
e8330b5fc5 Add eval_shape example for function with static arguments
Improve wording and formating of dynamic eval_shape example
2023-12-18 19:19:48 -05:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00