485 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Parker Schuh
86fe463ad7 [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
After:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

Reverts b615266175effe4aefeb903620a19f3719a604da

PiperOrigin-RevId: 675746175
2024-09-17 16:11:28 -07:00
Sergei Lebedev
83bccdd289 sharding and weak_type parameters of ShapeDtypeStruct are now keyword-only
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.

PiperOrigin-RevId: 674324964
2024-09-13 09:24:38 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
jax authors
4957ab9a5e Clean up JAX backend for all backends to avoid dangling PyClient references.
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Yash Katariya
b615266175 Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01
PiperOrigin-RevId: 668969133
2024-08-29 09:43:19 -07:00
Yash Katariya
dd6f0e2e2e Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss.

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
2024-08-29 08:35:42 -07:00
Yash Katariya
82c9da020a Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
2024-08-20 16:18:58 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yue Sheng
09beb33226 Don't call api.clean_up when there is no default backend.
PiperOrigin-RevId: 658936536
2024-08-02 16:14:29 -07:00
Yue Sheng
88c8bacdca Add util.clear_all_caches to api.clear_backends and let api.clear_backends be called before process terminates on JAX CPU. This could make the PjRt CPU client object to be successfully destroyed during Python garbage collection.
PiperOrigin-RevId: 658843789
2024-08-02 11:08:48 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
Yash Katariya
ff3dc0f5fb Add check_compatible_aval checks to Layout. It checks if len(major_to_minor) == len(aval.shape).
PiperOrigin-RevId: 651777179
2024-07-12 08:10:43 -07:00
Cristian Garcia
e45a95d96c Update jax/_src/api.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-07-02 16:10:36 +01:00
Cristian Garcia
756de6952f Update jax/_src/api.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-07-01 21:55:54 +01:00
Matthew Johnson
987194d4e9 prototyping improving vmap out_axes error
e.g.:

  jax.vmap(lambda x: (x, x), in_axes=0, out_axes=(0, None))(jnp.arange(3))

Co-authored-by: Cristian Garcia <cgarciae@google.com>
2024-07-01 21:55:53 +01:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Yash Katariya
73a8e7f0a8 Rename is_compatible_aval to check_compatible_aval since it returns None and not a bool.
PiperOrigin-RevId: 638431968
2024-05-29 15:29:32 -07:00
George Necula
bb4c073574 Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation of a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
with a `Lowered` object and will not include anymore its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializating. The Jaxpr is useful for exporting also the VJP of the function
and obtaining an `Exported` that can be differentiated.
2024-05-29 05:04:17 +03:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Lianmin Zheng
0eed28a010
Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Yash Katariya
8239674dab Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
jax authors
1f4c31d0af Merge pull request #20849 from mattjj:jit-docstring-tweaks
PiperOrigin-RevId: 627167535
2024-04-22 15:05:38 -07:00
Yash Katariya
1837b436d7 Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
2024-04-19 20:59:55 -07:00
Matthew Johnson
b8df23c25b tweak jit docstring 2024-04-19 17:37:52 -07:00
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00