18442 Commits

Author SHA1 Message Date
jax authors
a07ed22b02 Merge pull request #18756 from jakevdp:skips
PiperOrigin-RevId: 586795988
2023-11-30 14:56:11 -08:00
jax authors
efb4924699 Merge pull request #18754 from mattjj:fix-float0
PiperOrigin-RevId: 586781308
2023-11-30 14:05:11 -08:00
Matthew Johnson
43ed74f817 rewrite test not to include float0 broadcast 2023-11-30 13:53:13 -08:00
Jake VanderPlas
d6154e5d89 [array-api] remove some test skips 2023-11-30 13:28:08 -08:00
jax authors
5b3fc1bd5d Merge pull request #18730 from jakevdp:dep-device
PiperOrigin-RevId: 586761439
2023-11-30 12:58:16 -08:00
Mark Sandler
569f06cda7 In Python 3.11, asyncio.run() always tries to convert the repr of a coroutine's result to an integer while fetching the SIGINT handler. This makes the test materialize the whole tensor in memory. This change makes the test coroutine return nothing, to avoid triggering this bug.
https://github.com/python/cpython/issues/112559

PiperOrigin-RevId: 586756112
2023-11-30 12:37:12 -08:00
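The workaround described in 569f06cda7 can be sketched with the standard library alone. This is a minimal illustration, not the actual JAX test: `HugeResult`, `returns_value`, and `returns_nothing` are hypothetical names, and `HugeResult` merely stands in for a large array whose repr is expensive to compute.

```python
import asyncio


class HugeResult:
    """Hypothetical stand-in for a large tensor with an expensive repr."""

    def __repr__(self):
        # A real array would have to be materialized on the host here.
        return "HugeResult(...)"


huge = HugeResult()


async def returns_value():
    # The result is handed back through asyncio.run(), so anything that
    # inspects the finished task can end up calling repr() on it.
    return huge


async def returns_nothing():
    # Use the value inside the coroutine and return None instead, so
    # there is no large result attached to the finished task.
    _ = huge


result_a = asyncio.run(returns_value())
result_b = asyncio.run(returns_nothing())
```

Returning nothing keeps the large object out of the task result, which is the design choice the commit message describes.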
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
jax authors
4de07b3f62 Merge pull request #18753 from jakevdp:warnings-tests
PiperOrigin-RevId: 586739682
2023-11-30 11:37:24 -08:00
jax authors
53e66c1214 Merge pull request #18752 from mattjj:shmap-remat-rule
PiperOrigin-RevId: 586729063
2023-11-30 11:02:20 -08:00
jax authors
dab6379995 Merge pull request #18746 from olupton:fix-repeated-builds
PiperOrigin-RevId: 586723814
2023-11-30 10:47:45 -08:00
Shashank Viswanadha
bd46e5c960 Add nb::arg to nanobind definitions to generate better python annotations.
PiperOrigin-RevId: 586721759
2023-11-30 10:39:28 -08:00
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests 2023-11-30 10:35:24 -08:00
Matthew Johnson
5862852f85 [shard-map] add rewrite and replication checking rules for remat
these rules enable shmap-of-remat with check_rep=True
2023-11-30 10:15:48 -08:00
jax authors
11d7a2b860 Merge pull request #18741 from mattjj:shmap-test-fix
PiperOrigin-RevId: 586710378
2023-11-30 10:09:32 -08:00
Matthew Johnson
5c2635c205 [shard-map] fix test running broken by 0aec40a16fad02f084ef0cabd350db78b86b335e 2023-11-30 09:56:34 -08:00
Olli Lupton
e50c35d1e7 Fix repeatedly building JAX.
Reproducer was essentially running `pip install .` twice in a row in the
same source directory. Closes google/jax#18252.
2023-11-30 13:31:17 +01:00
jax authors
fe237cd776 Update XLA dependency to use revision 58e6b428e2.

PiperOrigin-RevId: 586556206
2023-11-29 22:44:13 -08:00
jax authors
d9a0cc0a07 Merge pull request #18731 from mattjj:shmap-custom-jvp-fix
PiperOrigin-RevId: 586527134
2023-11-29 20:24:54 -08:00
Matthew Johnson
b8f758e4a0 [shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr 2023-11-29 20:06:12 -08:00
Yash Katariya
e624610e72 Replace apply_primitive internals with jax.jit.
This allows deletion of a lot of code and speeds up eager dispatch by roughly 40%.

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```

PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
jax authors
b6c73f8992 Merge pull request #18740 from mattjj:shmap-conv-rules
PiperOrigin-RevId: 586499305
2023-11-29 17:33:24 -08:00
Matthew Johnson
6f20c0af38 [shard-map] add conv replication rules
fixes #18737
2023-11-29 16:58:54 -08:00
jax authors
57e19db104 Merge pull request #18736 from mattjj:device-put-fixes
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c fix grad device_put src inference, and a small device_put bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
jax authors
f0382a5838 Merge pull request #18728 from jakevdp:dep-device-buffer
PiperOrigin-RevId: 586481166
2023-11-29 16:10:56 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Peter Hawkins
842ca2ccc5 Use process_count(backend) in local_devices().
Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend.

We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change.

PiperOrigin-RevId: 586453157
2023-11-29 14:24:06 -08:00
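The reasoning in 842ca2ccc5 can be illustrated with a simplified model. This sketch is not the JAX implementation: the `Device` class and the device list are hypothetical, and the only assumption carried over from the commit message is that the process count is derived from the devices of one backend.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical device record: a backend name plus a process index."""
    backend: str
    process_index: int


def process_count(devices, backend):
    # Count processes using only the devices of the requested backend.
    return max(d.process_index for d in devices if d.backend == backend) + 1


def local_devices(devices, backend, process_index):
    # Validate the index against the process count of the SAME backend.
    n = process_count(devices, backend)
    if not 0 <= process_index < n:
        raise ValueError(f"Unknown process_index {process_index}")
    return [d for d in devices
            if d.backend == backend and d.process_index == process_index]


# Made-up topology: the TPU backend reports 2 processes, the CPU backend 4.
devs = ([Device("tpu", i) for i in (0, 0, 1, 1)]
        + [Device("cpu", i) for i in range(4)])

tpu_processes = process_count(devs, "tpu")
cpu_local = local_devices(devs, "cpu", 3)
```

If the CPU query were validated against the TPU process count (2), process index 3 would look out of range, which is the failure mode the commit message describes.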
Yash Katariya
d6637da431 Disable test_memory_cosumption test
PiperOrigin-RevId: 586426753
2023-11-29 12:50:46 -08:00
jax authors
0fce77a70e Merge pull request #18708 from jakevdp:array-equal-dep
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Adam Paszke
ef65ba8f32 Internal change
PiperOrigin-RevId: 586312803
2023-11-29 05:47:50 -08:00
Peter Hawkins
458a8962be Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
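The AllReduce + DynamicSlice decomposition mentioned in 458a8962be can be sketched in plain Python. This is a toy model over lists, one list per participant, and not XLA's lowering; the function names are chosen here for illustration.

```python
def all_reduce(shards):
    # Every participant ends up with the elementwise sum of all inputs.
    total = [sum(vals) for vals in zip(*shards)]
    return [list(total) for _ in shards]


def dynamic_slice(x, start, size):
    # Take `size` elements of `x` beginning at `start`.
    return x[start:start + size]


def reduce_scatter(shards):
    # Reduce everything, then let participant i keep only its own chunk.
    n = len(shards)
    size = len(shards[0]) // n
    reduced = all_reduce(shards)
    return [dynamic_slice(reduced[i], i * size, size) for i in range(n)]


inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]
outputs = reduce_scatter(inputs)
# outputs == [[11, 22], [33, 44]]
```

A direct ReduceScatter lowering avoids giving every participant the full reduced array before slicing, which is why the decomposition is only a fallback.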
jax authors
86d9398078 Update XLA dependency to use revision 422b8e7fb7.

PiperOrigin-RevId: 586305722
2023-11-29 05:14:13 -08:00
jax authors
528a90e97d Merge pull request #18724 from apaszke:downgrade-logging
PiperOrigin-RevId: 586304793
2023-11-29 05:06:18 -08:00
Peter Hawkins
1e961b80da Remove fallback path that lowers all_gather via psum.
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.

I'm planning to improve the XLA:CPU lowering in a subsequent change.

PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
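The idea behind lowering all_gather via psum, the fallback removed in 1e961b80da, can be sketched in plain Python. This is a toy model over lists and not the removed JAX code; the function names are illustrative.

```python
def psum(values):
    # Elementwise sum across participants; each one receives a copy.
    total = [sum(vals) for vals in zip(*values)]
    return [list(total) for _ in values]


def all_gather_via_psum(shards):
    n = len(shards)
    size = len(shards[0])
    padded = []
    for i, shard in enumerate(shards):
        # Participant i writes its shard into slot i of a zero buffer.
        buf = [0] * (n * size)
        buf[i * size:(i + 1) * size] = shard
        padded.append(buf)
    # Summing the zero-padded buffers reassembles the full array.
    return psum(padded)


gathered = all_gather_via_psum([[1, 2], [3, 4], [5, 6]])
# every participant holds [1, 2, 3, 4, 5, 6]
```

The fallback sums an n-times larger buffer on every participant, so a native all-gather is preferable wherever the backend supports it.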
Adam Paszke
d80c15aaee Downgrade a bunch of logging to DEBUG
The logs related to compilation cache ended up being quite chatty,
which is quite unlike the other logs in JAX. This downgrades a bunch
of them to debug, as they can always be enabled independently
using JAX config. This should also fix the recent failures in
logging_test.py.
2023-11-29 12:10:53 +00:00
Sharad Vikram
8dfbf90602 [Pallas/Mosaic] Add support for barrier semaphores
PiperOrigin-RevId: 586289340
2023-11-29 04:04:11 -08:00
jax authors
5bcf2311f9 Update XLA dependency to use revision 1317a629f3.

PiperOrigin-RevId: 586206098
2023-11-28 22:23:33 -08:00
jax authors
896d4cfbaf Disable task_using_cache_metric unit test while debugging.
This test is failing in the OSS environment. Temporarily
disabling the test while debugging.

PiperOrigin-RevId: 586144501
2023-11-28 17:04:23 -08:00
jax authors
7d4d912065 Merge pull request #18632 from mattjj:shmap-eager-custom-jvp-vjp
PiperOrigin-RevId: 586143644
2023-11-28 16:56:43 -08:00
jax authors
28063b886d Merge pull request #18718 from jakevdp:sparse-conj
PiperOrigin-RevId: 586131627
2023-11-28 16:14:20 -08:00
Matthew Johnson
7589c2bdb8 [shard_map] implement eager custom_jvp / custom_vjp 2023-11-28 16:08:56 -08:00
jax authors
2ccdfa6da2 Merge pull request #18711 from mattjj:shmap-transpose-fix-3
PiperOrigin-RevId: 586131306
2023-11-28 16:06:08 -08:00
Jake VanderPlas
05d18ac998 [sparse] add sparsify support for conj_p 2023-11-28 15:46:47 -08:00
Matthew Johnson
5fbda6d060 [shard-map] fix transpose replication checking bug with integer_pow 2023-11-28 15:39:32 -08:00
jax authors
5178afc81e Merge pull request #18714 from jakevdp:sparse-shape
PiperOrigin-RevId: 586123122
2023-11-28 15:34:16 -08:00
Jake VanderPlas
d6d7061d71 [sparse] canonicalize sparse shapes 2023-11-28 14:38:48 -08:00
Yash Katariya
cb7c2ed848 Use trace_to_jaxpr_dynamic for the apply_primitive path. trace_to_jaxpr_final is only for final style primitives. Also do some cleanup.
PiperOrigin-RevId: 586106427
2023-11-28 14:35:44 -08:00
jax authors
a66fea78b0 Merge pull request #18712 from jakevdp:warning-typo
PiperOrigin-RevId: 586103610
2023-11-28 14:25:51 -08:00
Jake VanderPlas
ae662be5ef Fix typo in deprecation warning 2023-11-28 13:56:49 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00