jax authors
8ad774fb10
Automate arguments for jax.distributed.initialize for cloud TPU environments.
...
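In practice, a Cloud TPU job should now be able to call initialize with no arguments (a minimal sketch, assuming the TPU environment provides the coordinator address, process count, and process id):
```
import jax

# On Cloud TPU, the coordinator address, number of processes, and
# process id can now be detected from the environment automatically.
jax.distributed.initialize()
```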
PiperOrigin-RevId: 586892544
2023-11-30 22:25:00 -08:00
Jake VanderPlas
97beb01c43
Deprecate the device() method of JAX arrays
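A hedged migration sketch (`devices()` and `sharding` are the documented alternatives as of this era of JAX):
```
import jax.numpy as jnp

x = jnp.arange(4)
# x.device()   # deprecated: emits a DeprecationWarning
x.devices()    # preferred: returns the set of devices backing the array
x.sharding     # or inspect the array's sharding directly
```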
2023-11-30 11:43:02 -08:00
Jake VanderPlas
d2b4800723
tests: improve warnings-related tests
2023-11-30 10:35:24 -08:00
Yash Katariya
e624610e72
Replace apply_primitive internals with jax.jit
...
This allows deletion of a lot of code and leads to a ~40% eager performance speedup.
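As a rough way to observe the eager-dispatch effect (a hypothetical micro-benchmark, not the harness used below):
```
import timeit
import jax.numpy as jnp

x = jnp.ones(())
jnp.sin(x).block_until_ready()  # warm up the dispatch cache
t = timeit.timeit(lambda: jnp.sin(x).block_until_ready(), number=10_000)
print(f"eager unary: {t / 10_000 * 1e6:.1f} µs/op")
```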
Benchmarks:
```
name old time/op new time/op delta
eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5)
eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5)
eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5)
eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5)
jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5)
jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5)
jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5)
jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5)
jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5)
jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5)
jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5)
jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8)
jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5)
jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5)
jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000 516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4)
jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5)
jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5)
jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5)
pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5)
pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5)
pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5)
pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5)
pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5)
sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5)
sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5)
sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5)
bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5)
bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5)
bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5)
bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% (p=0.008 n=5+5)
bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5)
bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5)
bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5)
bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5)
host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5)
device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5)
device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5)
device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5)
np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5)
jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5)
batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5)
serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5)
```
PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
jax authors
57e19db104
Merge pull request #18736 from mattjj:device-put-fixes
...
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c
fix grad device_put src inference, and a small device_put bug
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
Jake VanderPlas
0aec40a16f
Deprecate arr.device_buffer and arr.device_buffers
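A hedged migration sketch (assuming `addressable_data` / `addressable_shards` are the intended replacements, per the deprecation message):
```
import jax.numpy as jnp

x = jnp.arange(8)
# x.device_buffer       # deprecated
x.addressable_data(0)   # single-shard replacement
# x.device_buffers      # deprecated
[s.data for s in x.addressable_shards]  # multi-shard replacement
```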
2023-11-29 15:31:01 -08:00
Peter Hawkins
842ca2ccc5
Use process_count(backend) in local_devices().
...
Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend.
We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change.
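For context, `local_devices` accepts an explicit backend, and the process count used to validate the process index is now computed from that same backend (a minimal illustration):
```
import jax

# Both the devices and the process count behind this call now come
# from the CPU backend, avoiding cross-backend process-index skew.
print(jax.local_devices(backend="cpu"))
```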
PiperOrigin-RevId: 586453157
2023-11-29 14:24:06 -08:00
jax authors
0fce77a70e
Merge pull request #18708 from jakevdp:array-equal-dep
...
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Peter Hawkins
458a8962be
Always lower reduce_scatter_p as an HLO ReduceScatter.
...
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.
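For reference, `lax.psum_scatter` is the user-facing op that lowers through `reduce_scatter_p` (a sketch, assuming a multi-device pmap setup):
```
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
x = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)

# Sum across devices, with each device keeping only its slice of the
# result; this now always lowers to an HLO ReduceScatter.
out = jax.pmap(lambda v: lax.psum_scatter(v, "i"), axis_name="i")(x)
```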
PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
jax authors
528a90e97d
Merge pull request #18724 from apaszke:downgrade-logging
...
PiperOrigin-RevId: 586304793
2023-11-29 05:06:18 -08:00
Peter Hawkins
1e961b80da
Remove fallback path that lowers all_gather via psum.
...
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.
I'm planning to improve the XLA:CPU lowering in a subsequent change.
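A small sketch of the op in question (assuming a multi-device pmap setup):
```
import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)

# Each device contributes its shard and receives the full gathered
# array; this now lowers directly to an HLO AllGather on all backends.
out = jax.pmap(lambda v: lax.all_gather(v, "i"), axis_name="i")(x)
```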
PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
Adam Paszke
d80c15aaee
Downgrade a bunch of logging to DEBUG
...
The logs related to the compilation cache ended up being quite chatty,
which is unlike the other logs in JAX. This downgrades a bunch
of them to DEBUG, as they can always be enabled independently
using the JAX config. This should also fix the recent failures in
logging_test.py.
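If you want those logs back, something like the following should work (hedged: `jax_debug_log_modules` is assumed to be the config knob for per-module DEBUG logging):
```
import jax

# Re-enable DEBUG logging just for the compilation-cache module.
jax.config.update("jax_debug_log_modules", "jax._src.compilation_cache")
```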
2023-11-29 12:10:53 +00:00
Sharad Vikram
8dfbf90602
[Pallas/Mosaic] Add support for barrier semaphores
...
PiperOrigin-RevId: 586289340
2023-11-29 04:04:11 -08:00
Yash Katariya
cb7c2ed848
Use trace_to_jaxpr_dynamic for the apply_primitive path. trace_to_jaxpr_final is only for final-style primitives. Also do some cleanup.
...
PiperOrigin-RevId: 586106427
2023-11-28 14:35:44 -08:00
Jake VanderPlas
ae662be5ef
Fix typo in deprecation warning
2023-11-28 13:56:49 -08:00
Jake VanderPlas
13dd5e42cc
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
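A hedged sketch of the change in behavior (that the deprecation surfaces as a warning is an assumption):
```
import jax.numpy as jnp

jnp.array_equal([1, 2], [1, 2])                        # now warns: non-array inputs
jnp.array_equal(jnp.array([1, 2]), jnp.array([1, 2]))  # preferred
```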
2023-11-28 13:55:18 -08:00
Jake VanderPlas
a8723ecb9c
Fix grad of jnp.i0 at zero
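Since d/dx I0(x) = I1(x) and I1(0) = 0, the gradient at zero should be exactly zero (a sketch; that the pre-fix value was NaN is an assumption):
```
import jax
import jax.numpy as jnp

print(jax.grad(jnp.i0)(0.0))  # 0.0 after this fix
```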
2023-11-28 12:34:56 -08:00
Yash Katariya
f8e22ae512
Remove jaxpr_debug_info from MeshComputation since that information is available via AllArgsInfo
...
PiperOrigin-RevId: 586018345
2023-11-28 10:05:15 -08:00
George Necula
c6afdfd8d6
[shape_poly] Simplify the API for processing polymorphic_shape specifications
...
Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was
invoked as `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.
We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs, and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:
`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.
We also rename `export.poly_specs` to `export.arg_specs`.
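A sketch of the new API, following the invocation shown above (the module path and the `"b, _"` placeholder spec are illustrative assumptions):
```
import numpy as np
import jax
from jax.experimental import export

arg = np.zeros((3, 4), np.float32)
# "b" is a symbolic dimension; "_" is a placeholder filled from `like=`.
spec = jax.ShapeDtypeStruct(
    export.symbolic_shape("b, _", like=arg.shape), arg.dtype)
```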
2023-11-28 12:45:59 +02:00
Yash Katariya
88d980f164
Typecheck avals and sharding for arguments that were DCE'd.
...
This keeps the promise of AOT that no recompilation occurs.
Fixes https://github.com/google/jax/issues/18686
PiperOrigin-RevId: 585855658
2023-11-27 22:39:32 -08:00
Yash Katariya
37f11428a3
Fix indexing bug when querying _input_layouts
...
PiperOrigin-RevId: 585842302
2023-11-27 21:12:07 -08:00
Yash Katariya
81aee237d8
Simplify the lower_sharding_computation signature by always taking a closed jaxpr as input. For apply_primitive, do the tracing to jaxpr in dispatch.py
...
PiperOrigin-RevId: 585810475
2023-11-27 18:00:57 -08:00
Yash Katariya
2ed0fc4d1c
compiled.input_layouts() should preserve the structure of the original in_tree.
...
JAX by default DCEs arguments that are unused, which changes the in_layouts available on the `executable`. This breaks when we try to unflatten said in_layouts with the original in_tree (because the in_tree contains both DCE'd and non-DCE'd args).
The in_layouts that we return to the user should contain layouts for DCE'd and non-DCE'd args, so fill the DCE'd layouts with None, which means the default layout. This does not affect the actual HLO computation because JAX will discard the DCE'd layouts anyway, consequently discarding the jax.Arrays created with those layouts.
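A hedged sketch of the now-preserved structure (a hypothetical two-argument function; the method name and its visibility at this revision are assumptions):
```
import jax
import jax.numpy as jnp

def f(x, unused):        # `unused` is DCE'd during lowering
    return x * 2

compiled = jax.jit(f).lower(jnp.ones(4), jnp.ones(4)).compile()
# input_layouts() now has an entry per original argument, with None
# (the default layout) filled in for the DCE'd one.
```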
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 585790912
2023-11-27 16:23:58 -08:00
jax authors
b9b5410ddd
Default-enable the JAX persistent compilation cache.
...
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.
Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.
Testing: manual testing with test workloads.
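For reference, the cache directory can also be configured explicitly (a sketch; `jax_compilation_cache_dir` is the relevant config option, path is illustrative):
```
import jax

# Point the persistent compilation cache at an explicit directory.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```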
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
Jake VanderPlas
01fde43fce
Fix sign of jax.scipy.special.gamma for negative inputs
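Since Γ(-0.5) = -2√π ≈ -3.5449, results for negative non-integer inputs can be negative (a sketch; that the old implementation effectively returned |Γ(x)| is an assumption):
```
import jax.scipy.special as jsp

print(jsp.gamma(-0.5))  # ≈ -3.5449 after this fix
```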
2023-11-27 14:08:02 -08:00
Jake VanderPlas
8bc486a48f
Fix debug_nans issue in sort
2023-11-27 08:41:19 -08:00
Tomás Longeri
523f36153f
[Mosaic] Use C++ apply-vector-layout as default setting
...
PiperOrigin-RevId: 585663081
2023-11-27 08:39:42 -08:00
Adam Paszke
30a76756c3
[Mosaic] Accept that the dump flag might be missing.
...
The flag ends up in libtpu in the OSS build, so we will need to find a different
mechanism for this. This is a quick fix to get things working again.
PiperOrigin-RevId: 584835560
2023-11-23 02:21:29 -08:00
jax authors
bf2464ec82
Merge pull request #18652 from mattjj:debug-callback-docstring-2
...
PiperOrigin-RevId: 584731668
2023-11-22 15:22:30 -08:00
Matthew Johnson
997db225e2
small tweaks to jax.debug.print docstring
2023-11-22 15:03:41 -08:00
jax authors
4770690b4a
Merge pull request #18647 from jakevdp:pref-eltype-doc
...
PiperOrigin-RevId: 584715788
2023-11-22 14:07:04 -08:00
Jake VanderPlas
530cf30bfc
xla_bridge: add logic to avoid version skew
2023-11-22 12:17:21 -08:00
Jake VanderPlas
2acdb120a0
DOC: document preferred_element_type argument to dot functions
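A small illustration of the argument being documented (the dtypes here are illustrative):
```
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.bfloat16)
y = jnp.ones((128, 128), dtype=jnp.bfloat16)

# Accumulate the matmul in float32 even though the inputs are bfloat16.
z = jnp.dot(x, y, preferred_element_type=jnp.float32)
```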
2023-11-22 09:49:34 -08:00
Tom Hennigan
1b504bb68e
Allow threads to race setting attributes on Mesh.
...
PiperOrigin-RevId: 584602313
2023-11-22 05:47:56 -08:00
Matthew Johnson
67677eb10e
improve error message for e.g. jnp.zeros(5)[:, 0]
2023-11-21 15:59:21 -08:00
jax authors
2efa5862ad
Merge pull request #18563 from jakevdp:cleanup-fold-in
...
PiperOrigin-RevId: 584414134
2023-11-21 13:37:51 -08:00
Peter Hawkins
84c1e825c0
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
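In practice (a sketch of the calling convention; keyword use was possible before this change):
```
import jax.numpy as jnp

cond = jnp.array([True, False])
x, y = jnp.zeros(2), jnp.ones(2)

jnp.where(cond, x, y)                   # OK: positional, matches numpy
# jnp.where(condition=cond, x=x, y=y)   # no longer allowed: positional-only
```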
...
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
George Necula
a6284a03e5
Disable an integer_pow harness due to overflow.
...
PiperOrigin-RevId: 584261696
2023-11-21 02:23:22 -08:00
Peter Hawkins
49c80e68d1
Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
...
Improve the documentation of lax.eig().
Fixes https://github.com/google/jax/issues/18226
PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
fc8058a17d
Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
...
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.
Testing: test workload.
PiperOrigin-RevId: 584148686
2023-11-20 15:52:03 -08:00
Krasimir Georgiev
9287a6369d
Integrate LLVM at llvm/llvm-project@9bdbb8226e
...
Updates LLVM usage to match
[9bdbb8226e70](https://github.com/llvm/llvm-project/commit/9bdbb8226e70)
PiperOrigin-RevId: 584091615
2023-11-20 12:01:57 -08:00
Yash Katariya
d6a9352270
Only call executable.get_output_memory_kinds() if jax_enable_memories is True
...
PiperOrigin-RevId: 584087022
2023-11-20 11:44:30 -08:00
jax authors
ab9c973031
Merge pull request #18600 from nouiz:doc_compilation_cache
...
PiperOrigin-RevId: 584068904
2023-11-20 10:43:32 -08:00
jax authors
e388d48d1f
Merge pull request #18228 from JiaYaobo:random.binomial
...
PiperOrigin-RevId: 584045591
2023-11-20 09:16:41 -08:00
jax authors
4fd93c3226
Merge pull request #18593 from lgeiger:xeinsum-error-msg
...
PiperOrigin-RevId: 584024405
2023-11-20 07:47:16 -08:00
Frederic Bastien
72b6c9cf0b
Document the compilation cache
2023-11-20 07:03:30 -08:00
George Necula
2d9da6c8fb
Clean up the code for picking lowering rules based on platform.
...
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform`, had its own special case for
when a single lowering rule applied.
We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform lowering. We simplify
its API (it takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
jiayaobo
ae2387dc27
add random.binomial
...
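A hedged usage sketch of the new sampler (parameter names follow the NumPy-style convention assumed here):
```
import jax

key = jax.random.PRNGKey(0)
# Draw three binomial samples with n=10 trials and success probability 0.5.
samples = jax.random.binomial(key, n=10, p=0.5, shape=(3,))
```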
2023-11-19 14:51:10 +08:00
Yash Katariya
c8ef37507b
Make the SpecifiedLayout class opaque.
...
We also need to enable pickling for xc.Layout so that AOT serialization continues to work.
PiperOrigin-RevId: 583684299
2023-11-18 15:17:16 -08:00