jax authors
a07ed22b02
Merge pull request #18756 from jakevdp:skips
...
PiperOrigin-RevId: 586795988
2023-11-30 14:56:11 -08:00
jax authors
efb4924699
Merge pull request #18754 from mattjj:fix-float0
...
PiperOrigin-RevId: 586781308
2023-11-30 14:05:11 -08:00
Matthew Johnson
43ed74f817
rewrite test not to include float0 broadcast
2023-11-30 13:53:13 -08:00
Jake VanderPlas
d6154e5d89
[array-api] remove some test skips
2023-11-30 13:28:08 -08:00
jax authors
5b3fc1bd5d
Merge pull request #18730 from jakevdp:dep-device
...
PiperOrigin-RevId: 586761439
2023-11-30 12:58:16 -08:00
Mark Sandler
569f06cda7
In python 3.11 async.run() always tries to convert repr of the result of a coroutine as integer while fetching sigint handler. This makes the test materialize the whole tensor in memory. This changes the test co-routine to return nothing to avoid triggering this bug.
...
https://github.com/python/cpython/issues/112559
PiperOrigin-RevId: 586756112
2023-11-30 12:37:12 -08:00
Jake VanderPlas
97beb01c43
Deprecate the device() method of JAX arrays
2023-11-30 11:43:02 -08:00
jax authors
4de07b3f62
Merge pull request #18753 from jakevdp:warnings-tests
...
PiperOrigin-RevId: 586739682
2023-11-30 11:37:24 -08:00
jax authors
53e66c1214
Merge pull request #18752 from mattjj:shmap-remat-rule
...
PiperOrigin-RevId: 586729063
2023-11-30 11:02:20 -08:00
jax authors
dab6379995
Merge pull request #18746 from olupton:fix-repeated-builds
...
PiperOrigin-RevId: 586723814
2023-11-30 10:47:45 -08:00
Shashank Viswanadha
bd46e5c960
Add nb::arg
to nanobind definitions to generate better python annotations.
...
PiperOrigin-RevId: 586721759
2023-11-30 10:39:28 -08:00
Jake VanderPlas
d2b4800723
tests: improve warnings-related tests
2023-11-30 10:35:24 -08:00
Matthew Johnson
5862852f85
[shard-map] add rewrite and replication checking rules for remat
...
these rules enable shmap-of-remat with check_rep=True
2023-11-30 10:15:48 -08:00
jax authors
11d7a2b860
Merge pull request #18741 from mattjj:shmap-test-fix
...
PiperOrigin-RevId: 586710378
2023-11-30 10:09:32 -08:00
Matthew Johnson
5c2635c205
[shard-map] fix test running broken by 0aec40a16fad02f084ef0cabd350db78b86b335e
2023-11-30 09:56:34 -08:00
Olli Lupton
e50c35d1e7
Fix repeatedly building JAX.
...
Reproducer was essentially running `pip install .` twice in a row in the
same source directory. Closes google/jax#18252 .
2023-11-30 13:31:17 +01:00
jax authors
fe237cd776
Update XLA dependency to use revision
...
58e6b428e2
.
PiperOrigin-RevId: 586556206
2023-11-29 22:44:13 -08:00
jax authors
d9a0cc0a07
Merge pull request #18731 from mattjj:shmap-custom-jvp-fix
...
PiperOrigin-RevId: 586527134
2023-11-29 20:24:54 -08:00
Matthew Johnson
b8f758e4a0
[shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr
2023-11-29 20:06:12 -08:00
Yash Katariya
e624610e72
Replace apply_primitive internals with jax.jit
.
...
This allows deletion of a lot of code and leads to ~40% eager performance speedup.
Benchmarks:
```
name old time/op new time/op delta
eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5)
eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5)
eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5)
eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5)
jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5)
jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5)
jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5)
jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5)
jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5)
jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5)
jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5)
jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8)
jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5)
jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5)
jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000 516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4)
jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5)
jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5)
jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5)
pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5)
pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5)
pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5)
pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5)
pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5)
sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5)
sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5)
sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5)
bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5)
bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5)
bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5)
bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% (p=0.008 n=5+5)
bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5)
bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5)
bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5)
bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5)
host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5)
device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5)
device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5)
device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5)
np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5)
jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5)
batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5)
serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5)
```
PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
jax authors
b6c73f8992
Merge pull request #18740 from mattjj:shmap-conv-rules
...
PiperOrigin-RevId: 586499305
2023-11-29 17:33:24 -08:00
Matthew Johnson
6f20c0af38
[shard-map] add conv replication rules
...
fixes #18737
2023-11-29 16:58:54 -08:00
jax authors
57e19db104
Merge pull request #18736 from mattjj:device-put-fixes
...
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c
fix grad device_put src inference, and a small device_put bug
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
jax authors
f0382a5838
Merge pull request #18728 from jakevdp:dep-device-buffer
...
PiperOrigin-RevId: 586481166
2023-11-29 16:10:56 -08:00
Jake VanderPlas
0aec40a16f
Deprecate arr.device_buffer and arr.device_buffers
2023-11-29 15:31:01 -08:00
Peter Hawkins
842ca2ccc5
Use process_count(backend) in local_devices().
...
Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend.
We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change.
PiperOrigin-RevId: 586453157
2023-11-29 14:24:06 -08:00
Yash Katariya
d6637da431
Disable test_memory_cosumption test
...
PiperOrigin-RevId: 586426753
2023-11-29 12:50:46 -08:00
jax authors
0fce77a70e
Merge pull request #18708 from jakevdp:array-equal-dep
...
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Adam Paszke
ef65ba8f32
Internal change
...
PiperOrigin-RevId: 586312803
2023-11-29 05:47:50 -08:00
Peter Hawkins
458a8962be
Always lower reduce_scatter_p as an HLO ReduceScatter.
...
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.
PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
jax authors
86d9398078
Update XLA dependency to use revision
...
422b8e7fb7
.
PiperOrigin-RevId: 586305722
2023-11-29 05:14:13 -08:00
jax authors
528a90e97d
Merge pull request #18724 from apaszke:downgrade-logging
...
PiperOrigin-RevId: 586304793
2023-11-29 05:06:18 -08:00
Peter Hawkins
1e961b80da
Remove fallback path that lowers all_gather via psum.
...
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.
I'm planning to improve the XLA:CPU lowering in a subsequent change.
PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
Adam Paszke
d80c15aaee
Downgrade a bunch of logging to DEBUG
...
The logs related to compilation cache ended up being quite chatty,
which is quite unlike the other logs in JAX. This downgrades a bunch
of them to debug, as they can always be enabled independently
using JAX config. This should also fix the recent failures in
logging_test.py.
2023-11-29 12:10:53 +00:00
Sharad Vikram
8dfbf90602
[Pallas/Mosaic] Add support for barrier semaphores
...
PiperOrigin-RevId: 586289340
2023-11-29 04:04:11 -08:00
jax authors
5bcf2311f9
Update XLA dependency to use revision
...
1317a629f3
.
PiperOrigin-RevId: 586206098
2023-11-28 22:23:33 -08:00
jax authors
896d4cfbaf
Disable task_using_cache_metric unit test while debugging.
...
This test is failing in the OSS environment. Temporarily
disabling the test while debugging.
PiperOrigin-RevId: 586144501
2023-11-28 17:04:23 -08:00
jax authors
7d4d912065
Merge pull request #18632 from mattjj:shmap-eager-custom-jvp-vjp
...
PiperOrigin-RevId: 586143644
2023-11-28 16:56:43 -08:00
jax authors
28063b886d
Merge pull request #18718 from jakevdp:sparse-conj
...
PiperOrigin-RevId: 586131627
2023-11-28 16:14:20 -08:00
Matthew Johnson
7589c2bdb8
[shard_map] implement eager custom_jvp / custom_vjp
2023-11-28 16:08:56 -08:00
jax authors
2ccdfa6da2
Merge pull request #18711 from mattjj:shmap-transpose-fix-3
...
PiperOrigin-RevId: 586131306
2023-11-28 16:06:08 -08:00
Jake VanderPlas
05d18ac998
[sparse] add sparsify support for conj_p
2023-11-28 15:46:47 -08:00
Matthew Johnson
5fbda6d060
[shard-map] fix transpose replication checking bug with integer_pow
2023-11-28 15:39:32 -08:00
jax authors
5178afc81e
Merge pull request #18714 from jakevdp:sparse-shape
...
PiperOrigin-RevId: 586123122
2023-11-28 15:34:16 -08:00
Jake VanderPlas
d6d7061d71
[sparse] canonicalize sparse shapes
2023-11-28 14:38:48 -08:00
Yash Katariya
cb7c2ed848
Use trace_to_jaxpr_dynamic
for the apply_primitive path. trace_to_jaxpr_final
is only for final style primitives. Also do some cleanup.
...
PiperOrigin-RevId: 586106427
2023-11-28 14:35:44 -08:00
jax authors
a66fea78b0
Merge pull request #18712 from jakevdp:warning-typo
...
PiperOrigin-RevId: 586103610
2023-11-28 14:25:51 -08:00
Jake VanderPlas
ae662be5ef
Fix typo in deprecation warning
2023-11-28 13:56:49 -08:00
Jake VanderPlas
13dd5e42cc
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
2023-11-28 13:55:18 -08:00