978 Commits

Author SHA1 Message Date
Yash Katariya
67b0eb3af4 Improve pytree mismatch error in AOT
PiperOrigin-RevId: 612560820
2024-03-04 13:15:32 -08:00
Peter Hawkins
fdbee314d3 Make JAX tests that check for errors from dict key comparators in pytrees more relaxed, in preparation for https://github.com/openxla/xla/pull/9529.
PiperOrigin-RevId: 610819296
2024-02-27 11:30:10 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.*
2024-02-26 14:17:18 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map
2024-02-22 11:35:39 -08:00
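As a rough illustration of the rename in e59a0506fe, the sketch below uses the newer `jax.tree.map` spelling on a small nested container. The container and the doubling function are made up for illustration; only the `jax.tree.map` name comes from the commit message.

```python
import jax

# A small hypothetical pytree: a dict holding a scalar and a tuple.
params = {"a": 1, "b": (2, 3)}

# jax.tree.map applies the function to every leaf and keeps the structure.
doubled = jax.tree.map(lambda leaf: leaf * 2, params)
print(doubled)
```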
Matthew Johnson
bc1e5f0220 [custom_vjp] handle Nones in subtrees returned by bwd rule
fixes #8356
2024-02-21 00:37:04 -08:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays
2024-02-16 11:26:30 -08:00
George Necula
ddc248797e Fix JVP rule for convert_element_type to integer types.
Without this fix, the newly added test fails:

```
File "/Users/necula/Source/jax/jax/_src/lax/lax.py", line 2382, in _convert_element_type_jvp_rule
    return ad_util.Zero(tangent.aval.update(dtype=dtypes.float0, weak_type=False))
                         ^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'aval'
```

because `tangent` is a float constant.
2024-02-15 14:09:04 +01:00
jax authors
9b27d43e70 Import submodules from jax._src explicitly, instead of relying on import side effects. Relying on side effects leads to missing x-refs in code search, according to go/pywald-sawmill-analysis.
PiperOrigin-RevId: 604788105
2024-02-06 15:47:16 -08:00
Yash Katariya
d9122b8bac Add sharding to ShapeDtypeStruct returned by eval_shape if jit has out_shardings specified
PiperOrigin-RevId: 602556016
2024-01-29 18:02:51 -08:00
jax authors
1f380e0231 Merge pull request #19413 from jakevdp:dep-tie-in
PiperOrigin-RevId: 599688284
2024-01-18 18:52:52 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in
2024-01-18 13:13:47 -08:00
Yash Katariya
51ef738c86 Use jit's jaxpr creation function for eval_shape to maximize tracing cache hits.
This comes up in LLM models, where we trace twice (once for eval_shape, usually on the init function, and again during jit) even though the output jaxpr is the same. This shouldn't happen; we should cache as much as possible.

The only caveat is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. That may be acceptable if we want to deprecate eval_shape in favor of an AOT-style method on `jax.jit`, or make it a thin wrapper around something like `jax.jit(f).eval_shape`.

PiperOrigin-RevId: 599602407
2024-01-18 13:11:44 -08:00
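For readers unfamiliar with the pattern mentioned in 51ef738c86, here is a minimal sketch of calling `jax.eval_shape` on an init-style function. The `init` function and its shapes are hypothetical; the sketch only shows that `eval_shape` reports output shapes and dtypes from an abstract input, and says nothing about how the tracing cache is shared with `jit`.

```python
import jax
import jax.numpy as jnp

def init(x):
    # Hypothetical init function: builds parameters sized from the input.
    return {
        "w": jnp.ones((x.shape[-1], 8), dtype=x.dtype),
        "b": jnp.zeros((8,), dtype=x.dtype),
    }

# eval_shape traces `init` on an abstract input; no real input data is needed.
spec = jax.ShapeDtypeStruct((4, 16), jnp.float32)
out = jax.eval_shape(init, spec)
print(out["w"].shape, out["w"].dtype)
```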
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Matthew Johnson
9112dcebc9 add jax.explain_cache_misses tracing cache miss explanations
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information.

The main idea here is that pointing out which parts of the cache key differ
from previously-seen keys can constitute a pretty good explanation.

This PR adds an explanation mechanism. It can be enabled in a few different ways:
  * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
  * setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
  * using the context manager `jax._src.config.explain_cache_misses`
    (not in public namespace yet);
  * when parsing command line flags with absl, using the
    `--jax_explain_cache_misses` flag.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-26 21:54:27 -08:00
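The core idea in 9112dcebc9, explaining a miss by pointing at the parts of the cache key that differ from previously seen keys, can be sketched in plain Python. Everything below is hypothetical: the `CacheKey` fields and the `explain_miss` helper are simplified stand-ins and are not JAX's actual cache key or implementation.

```python
from typing import NamedTuple

class CacheKey(NamedTuple):
    # Hypothetical, simplified key; a real tracing cache key has more parts.
    fun_id: int
    arg_shapes: tuple
    arg_dtypes: tuple

def explain_miss(new_key, seen):
    """Describe how new_key differs from the closest previously seen key."""
    if not seen:
        return "first call: nothing cached yet"

    def differing(old):
        return [f for f in CacheKey._fields
                if getattr(old, f) != getattr(new_key, f)]

    # Compare against the seen key that shares the most fields.
    closest = min(seen, key=lambda old: len(differing(old)))
    fields = differing(closest)
    if not fields:
        return "cache hit"
    parts = [f"{f}: {getattr(closest, f)!r} -> {getattr(new_key, f)!r}"
             for f in fields]
    return "cache miss, differs in " + "; ".join(parts)

seen = [CacheKey(1, ((3,),), ("float32",))]
message = explain_miss(CacheKey(1, ((3,),), ("int32",)), seen)
print(message)
```

In this toy version a dtype change is reported as the only differing field, which is the kind of hint the commit message argues is usually enough to suggest a fix.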
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable
2023-12-21 17:04:21 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences
2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
70106502be fix gc bug in PjitFunction (not clearing/visiting all children)
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Parker Schuh <parkers@google.com>
PiperOrigin-RevId: 589271238
2023-12-08 16:23:41 -08:00
Matthew Johnson
64cb53f624 improve an error message during Mesh creation
2023-12-06 16:43:36 -08:00
Jake VanderPlas
c2a0530274 jaxpr: improve printed repr when eqn has no return values
2023-12-06 10:45:24 -08:00
Matthew Johnson
43ed74f817 rewrite test not to include float0 broadcast
2023-11-30 13:53:13 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays
2023-11-30 11:43:02 -08:00
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests
2023-11-30 10:35:24 -08:00
Yash Katariya
e624610e72 Replace apply_primitive internals with jax.jit.
This allows deletion of a lot of code and leads to ~40% eager performance speedup.

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```

PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jake VanderPlas
0bcd64ade3 jax.vmap: improve docs & error for structured in_axes
2023-11-15 11:56:53 -08:00
jax authors
871b79925e Fix test failures when we update the abseil hashtable implementation.
PiperOrigin-RevId: 581988519
2023-11-13 09:24:06 -08:00
Peter Hawkins
1611e1bc41 Remove PythonJitTest from api_test.py.
Ever since the jit-pjit merge, the "Python" jit test has actually just called the same code as the "C++" jit test. We don't have a C++-free jit path any more. Remove the "Python" tests since they don't test anything.

PiperOrigin-RevId: 581965049
2023-11-13 08:03:23 -08:00
Junwhan Ahn
55394a0914 Roll back the optimized version of jax.block_until_ready due to test breakage
Reverts 6cc6d093643c0265c7de4027f79879f6945e0342

PiperOrigin-RevId: 581577789
2023-11-11 12:15:45 -08:00
Junwhan Ahn
6cc6d09364 Implement more efficient jax.block_until_ready(x) in C++
The current implementation synchronously calls `ArrayImpl.block_until_ready()` on each array, one at a time. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes the GIL to be acquired/released repeatedly.

To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.

PiperOrigin-RevId: 581302290
2023-11-10 10:34:34 -08:00
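The difference described in 6cc6d09364 between blocking on each array in turn and collecting readiness handles to wait on once can be approximated with standard-library futures. This is only an analogy under stated assumptions: the futures below stand in for arrays, `make_pending` is a hypothetical helper, and nothing here uses IFRT or JAX.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait

def make_pending(pool, n, delay=0.05):
    """Stand-ins for arrays whose results become ready asynchronously."""
    return [pool.submit(time.sleep, delay) for _ in range(n)]

def block_one_by_one(futures):
    # Analogous to the one-at-a-time path: one blocking call per object.
    for f in futures:
        f.result()

def block_all_at_once(futures):
    # Analogous to the batched path: gather the handles, then wait once.
    done, not_done = wait(futures)
    return len(done), len(not_done)

with ThreadPoolExecutor(max_workers=8) as pool:
    block_one_by_one(make_pending(pool, 4))
    ready, pending = block_all_at_once(make_pending(pool, 4))
print(ready, pending)
```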
jax authors
e227536fd6 In api_test.py, wait for the result in test_double_donation.
PiperOrigin-RevId: 579267104
2023-11-03 12:23:55 -07:00
Peter Hawkins
011d49c518 Add a test for double donation.
The underlying issue was fixed some time ago.

Fixes https://github.com/google/jax/issues/9635

PiperOrigin-RevId: 579170638
2023-11-03 07:03:13 -07:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl()
2023-10-19 13:59:01 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Jake VanderPlas
709b05f12f jax.make_jaxpr: fix __name__ & related attributes
2023-10-09 15:12:28 -07:00
Jake VanderPlas
eb5981a9dc api_test: fix platform-dependent deprecation warning
2023-10-06 13:40:08 -07:00
Matthew Willson
17d89ad166 Fix jax.device_put so it doesn't use tree_map for _check_sharding.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.

PiperOrigin-RevId: 570189648
2023-10-02 15:01:03 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each is passed as a separate argument
through several layers of internal lowering functions. This is tedious
and error-prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here we pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
2023-09-29 08:23:05 +03:00
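The plumbing problem described in 552fef6fcd can be sketched with a small frozen dataclass. The class name `LoweringParametersSketch`, the field types and defaults, and the `lower_eqn` / `lower_jaxpr` helpers are all hypothetical; only the two field names are taken from the commit message, and this is not the real `mlir.LoweringParameters`.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class LoweringParametersSketch:
    # Field names follow the commit message; types and defaults are assumed.
    lowering_platform: Optional[str] = None
    override_lowering_rules: tuple = ()

def lower_eqn(eqn, params):
    # Innermost layer: the only place that actually reads the parameters.
    platform = params.lowering_platform or "default"
    overridden = eqn in dict(params.override_lowering_rules)
    return f"{eqn}@{platform}" + (" (overridden)" if overridden else "")

def lower_jaxpr(eqns, params):
    # Middle layer: forwards one required object. Forgetting to pass it is a
    # TypeError rather than a silent fallback to a default argument.
    return [lower_eqn(e, params) for e in eqns]

params = LoweringParametersSketch(
    lowering_platform="tpu",
    override_lowering_rules=(("add", "custom_add_rule"),),
)
result = lower_jaxpr(["add", "mul"], params)
print(result)
```

Adding a further parameter later then means adding one field, rather than touching the signature of every intermediate function.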
Junwhan Ahn
8bfe3b92bc Roll back f92a70a41e
Reverts bb4382f0bce074ab081e1e02871e32ba331d1d46

PiperOrigin-RevId: 569292433
2023-09-28 14:32:23 -07:00
Junwhan Ahn
bb4382f0bc Destruct objects owned by WeakRefLRUCache::CacheEntry out of band using GlobalPyRefManager()
This assumes less about whether the thread that destructs `CacheEntry` holds the GIL, which is difficult to reason about due to `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.

The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:

* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.

PiperOrigin-RevId: 569062402
2023-09-27 22:15:22 -07:00
Peter Hawkins
210fab1aae Remove the "No GPU/TPU found" warning.
Instead, add a lightweight test for NVIDIA GPUs and Google TPUs. Warn
only if we suspect either is present but JAX is not using them.
2023-09-26 19:04:34 +00:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Yash Katariya
8276038f63 Relax the memory alignment check between numpy array and jax array on CPU
PiperOrigin-RevId: 567722405
2023-09-22 14:49:00 -07:00
Jake VanderPlas
bfed3d862e Improve behavior of core.valid_jaxtype
2023-09-22 13:46:09 -07:00
Yash Katariya
426970591b If an input to jnp.asarray is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
Do a similar thing for jax.Array too if dtypes match.

Fixes https://github.com/google/jax/issues/17702

PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
0dc2252f71 Better errors for array scalar/boolean conversion
2023-09-19 09:00:19 -07:00
Parker Schuh
21389415cc Add support for float flags to compiler_options.
PiperOrigin-RevId: 565475731
2023-09-14 14:19:39 -07:00
jax authors
6b5af15eea Merge pull request #17593 from jakeh-gc:test_changes
PiperOrigin-RevId: 565428268
2023-09-14 11:30:55 -07:00
Yash Katariya
a2720ee2c3 Deprecate jax.experimental.pjit.with_sharding_constraint. The replacement is jax.lax.with_sharding_constraint, which has been available for a year.
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00