480 Commits

Author SHA1 Message Date
Yash Katariya
de9b98e0a8 Delete jax.xla_computation; it has been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
jax authors
4957ab9a5e Clean up JAX backend for all backends to avoid dangling PyClient references.
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Yash Katariya
b615266175 Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01
PiperOrigin-RevId: 668969133
2024-08-29 09:43:19 -07:00
Yash Katariya
dd6f0e2e2e Add weak_type to ShapeDtypeStruct, because jax.Array also has it and SDS duck-types as jax.Array.
This fixes a tracing cache miss: calling eval_shape with a weak_type input returned a strong-type output, and passing that output back in missed the cache.
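A minimal sketch of the round-trip this fixes (the exact caching behavior is not asserted here):

```python
import jax

def f(x):
    return x * 2

# A weak-type input: Python scalars abstractify with weak_type=True.
spec = jax.eval_shape(f, 1.0)
print(spec)  # a ShapeDtypeStruct, now carrying weak_type as well

# Because SDS duck-types as jax.Array, the result can be fed straight back
# in; preserving weak_type lets this second call reuse the tracing cache.
spec2 = jax.eval_shape(f, spec)
print(spec2.shape, spec2.dtype)
```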

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
2024-08-29 08:35:42 -07:00
Yash Katariya
82c9da020a Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
2024-08-20 16:18:58 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit, and the array the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Without the resharding, we get incorrect outputs on GPU and a crash on TPU.
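The commit concerns layouts passed via in_shardings; the resharding path can be sketched with a plain Sharding (layout construction omitted; `SingleDeviceSharding` is just an illustrative choice):

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s = SingleDeviceSharding(jax.devices()[0])

# An uncommitted array: created without an explicit device placement.
x = jnp.arange(4.0)

# With in_shardings given, jit reshards the uncommitted input to match
# what the user specified, rather than keeping its default placement.
f = jax.jit(lambda x: x * 2, in_shardings=s)
y = f(x)
print(y)  # [0. 2. 4. 6.]
```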

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yue Sheng
09beb33226 Don't call api.clean_up when there is no default backend.
PiperOrigin-RevId: 658936536
2024-08-02 16:14:29 -07:00
Yue Sheng
88c8bacdca Add util.clear_all_caches to api.clear_backends, and let api.clear_backends be called before the process terminates on JAX CPU. This allows the PjRt CPU client object to be successfully destroyed during Python garbage collection.
PiperOrigin-RevId: 658843789
2024-08-02 11:08:48 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
Yash Katariya
ff3dc0f5fb Add check_compatible_aval checks to Layout. It checks if len(major_to_minor) == len(aval.shape).
PiperOrigin-RevId: 651777179
2024-07-12 08:10:43 -07:00
Cristian Garcia
e45a95d96c Update jax/_src/api.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-07-02 16:10:36 +01:00
Cristian Garcia
756de6952f Update jax/_src/api.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-07-01 21:55:54 +01:00
Matthew Johnson
987194d4e9 prototyping an improved vmap out_axes error
e.g.:

  jax.vmap(lambda x: (x, x), in_axes=0, out_axes=(0, None))(jnp.arange(3))

Co-authored-by: Cristian Garcia <cgarciae@google.com>
2024-07-01 21:55:53 +01:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available via `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable AOT APIs, making this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!
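A minimal sketch of the AOT entry point described above; the cross-lowering call itself is shown only as a comment, since lowering for another platform may require that platform's plugin:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x + 1

traced = jax.jit(f).trace(jnp.arange(3))  # a stages.Traced object
lowered = traced.lower()                  # lowers for the default backend

# Per this change, cross-lowering is requested at this step instead, e.g.:
#   traced.lower(lowering_platforms=('tpu',))
print(lowered.as_text()[:60])
```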

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
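A sketch of the pattern that now batches: one `device_put` over a whole pytree, rather than one transfer per leaf:

```python
import jax
import jax.numpy as jnp

# A pytree of arrays: after this change, a single device_put call issues
# one device_put_p.bind() for the whole tree, amortizing shard_args calls.
tree = {"w": jnp.zeros((2, 3)), "b": jnp.ones((3,))}
on_dev = jax.device_put(tree, jax.devices()[0])

from jax import tree_util
print(tree_util.tree_map(lambda a: a.shape, on_dev))
```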

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched: each handler now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches, and move pjit, sharding, array, etc. uses of functools.lru_cache to util.cache so that those caches are cleared when jax.clear_caches is called.
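A minimal sketch of the user-visible behavior (the internal `util.cache` plumbing is omitted):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * x

y0 = f(jnp.float32(2.0))  # populates the tracing/compilation caches

# After this change, clear_caches also clears the util.cache-backed
# caches in pjit, sharding, array, etc., not just the jit caches.
jax.clear_caches()

y = f(jnp.float32(2.0))   # recompiles; the result is unchanged
print(y)  # 4.0
```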
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementations, this is a good cleanup!

Fixes https://github.com/google/jax/issues/21116
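A minimal sketch of the two (now-merged) entry points; the `jaxpr` attribute on the traced object is per the surrounding commits:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + x

# Both paths share one implementation after this change:
closed = jax.make_jaxpr(f)(1.0)    # still closes over tracers (unchanged)
traced = jax.jit(f).trace(1.0)     # stages.Traced; does not close over tracers

print(closed)
print(traced.jaxpr)
```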

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
73a8e7f0a8 Rename is_compatible_aval to check_compatible_aval since it returns None and not a bool.
PiperOrigin-RevId: 638431968
2024-05-29 15:29:32 -07:00
George Necula
bb4c073574 Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation for a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
from a `Lowered` object and will no longer include its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializing. The Jaxpr is also useful for exporting the VJP of the function
and obtaining an `Exported` that can be differentiated.
2024-05-29 05:04:17 +03:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Lianmin Zheng
0eed28a010 Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Yash Katariya
8239674dab Replace donation_vector's logic with that of donation_vector_with_in_tree, which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
jax authors
1f4c31d0af Merge pull request #20849 from mattjj:jit-docstring-tweaks
PiperOrigin-RevId: 627167535
2024-04-22 15:05:38 -07:00
Yash Katariya
1837b436d7 Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
2024-04-19 20:59:55 -07:00
Matthew Johnson
b8df23c25b tweak jit docstring 2024-04-19 17:37:52 -07:00
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If the aval and sharding are the same, there's no need to check multiple times.
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class that wraps a jax.Array. Currently, we use a singleton, empty core.token object everywhere. After this change, tokens can be created and threaded in and out of computations to build up dependencies.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Yash Katariya
90401d51e9 Accept a layout on ShapeDtypeStruct via the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
Matthew Johnson
83a200a42f simple fix to make_jaxpr docstring
maybe it was accidentally copied from xla_computation before?
2024-04-11 21:50:51 -07:00
Sai-Suraj-27
5564521308 Prefer raising of TypeError for invalid types instead of ValueError. 2024-04-08 13:08:24 +05:30
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Peter Hawkins
d3e03fff5d Refactorings to the jit implementation.
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
2024-03-21 05:37:32 -07:00
Junwhan Ahn
f569031456 Reverts 55394a0914dc0583427a4ceb73dac56348911d15
PiperOrigin-RevId: 616201321
2024-03-15 11:56:45 -07:00