4942 Commits

Author SHA1 Message Date
jax authors
8ad774fb10 Automate arguments for jax.distributed.initialize for cloud TPU environments.
PiperOrigin-RevId: 586892544
2023-11-30 22:25:00 -08:00
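
A minimal usage sketch (assuming a Cloud TPU environment, where the coordinator address, process id, and process count can now be auto-detected):

```
import jax

# On Cloud TPU, initialize() can now infer its arguments automatically.
jax.distributed.initialize()

print(jax.process_index(), jax.process_count())
```
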
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
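
A sketch of the migration path, using the standard jax.Array accessors as the assumed replacements:

```
import jax.numpy as jnp

x = jnp.arange(4)
# Instead of the deprecated x.device():
devices = x.devices()   # the set of Devices the array lives on
sharding = x.sharding   # full placement information across devices
```
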
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests 2023-11-30 10:35:24 -08:00
Yash Katariya
e624610e72 Replace apply_primitive internals with jax.jit.
This allows deletion of a lot of code and leads to a ~40% eager performance speedup.

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```
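
For context, "eager" here means op-by-op dispatch outside an explicit jit; a rough illustration of the two paths being benchmarked (standard JAX API, not code from this change):

```
import jax
import jax.numpy as jnp

x = jnp.ones((100, 100))

y = jnp.sin(x)            # eager: the sin primitive is dispatched on its own,
                          # now routed through the same machinery as jax.jit
y2 = jax.jit(jnp.sin)(x)  # staged: the whole function is compiled once and cached
```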

PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
jax authors
57e19db104 Merge pull request #18736 from mattjj:device-put-fixes
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c fix grad device_put src inference, and a small device_put bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
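
A sketch of the likely migration, assuming the standard jax.Array shard accessors as the replacements:

```
import jax.numpy as jnp

x = jnp.arange(8)
# Instead of the deprecated x.device_buffer / x.device_buffers:
shards = x.addressable_shards   # list of Shard objects, each with .data and .device
first = x.addressable_data(0)   # the data of the first addressable shard
```
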
Peter Hawkins
842ca2ccc5 Use process_count(backend) in local_devices().
Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend.

We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change.

PiperOrigin-RevId: 586453157
2023-11-29 14:24:06 -08:00
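
For illustration, a usage sketch of the public entry point involved (not code from the change):

```
import jax

# local_devices() now derives its process-count bound from the same backend
# it is querying, which matters when TPU and CPU backends coexist in one job.
cpu_devices = jax.local_devices(backend="cpu")
```
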
jax authors
0fce77a70e Merge pull request #18708 from jakevdp:array-equal-dep
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Peter Hawkins
458a8962be Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
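
The user-facing op that reaches this lowering is lax.psum_scatter; a minimal sketch, assuming at least one local device:

```
from functools import partial
import jax
import jax.numpy as jnp

n = jax.local_device_count()

@partial(jax.pmap, axis_name="i")
def f(x):
    # lowers to an HLO ReduceScatter on every backend after this change
    return jax.lax.psum_scatter(x, "i")

print(f(jnp.ones((n, n))))  # each device receives one slice of the summed result
```
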
jax authors
528a90e97d Merge pull request #18724 from apaszke:downgrade-logging
PiperOrigin-RevId: 586304793
2023-11-29 05:06:18 -08:00
Peter Hawkins
1e961b80da Remove fallback path that lowers all_gather via psum.
As far as I can tell this is no longer necessary on GPU, which handles arbitrary all-gather dimensions (by making the gather dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX does.

I'm planning to improve the XLA:CPU lowering in a subsequent change.

PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
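
Correspondingly, lax.all_gather now always lowers to an HLO AllGather; a minimal sketch:

```
from functools import partial
import jax
import jax.numpy as jnp

n = jax.local_device_count()

@partial(jax.pmap, axis_name="i")
def g(x):
    # previously fallback-lowered via psum on some backends; now a direct AllGather
    return jax.lax.all_gather(x, "i")

print(g(jnp.arange(n * 2.0).reshape(n, 2)).shape)  # (n, n, 2)
```
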
Adam Paszke
d80c15aaee Downgrade a bunch of logging to DEBUG
The logs related to the compilation cache ended up being quite chatty,
unlike the other logs in JAX. This downgrades a bunch of them to DEBUG,
as they can always be re-enabled independently via JAX config. This
should also fix the recent failures in logging_test.py.
2023-11-29 12:10:53 +00:00
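
A sketch of re-enabling these logs selectively; the jax_debug_log_modules option existed at the time, though the exact module path below is an assumption:

```
import jax

# Turn DEBUG logging back on for just the compilation-cache module.
jax.config.update("jax_debug_log_modules", "jax._src.compilation_cache")
```
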
Sharad Vikram
8dfbf90602 [Pallas/Mosaic] Add support for barrier semaphores
PiperOrigin-RevId: 586289340
2023-11-29 04:04:11 -08:00
Yash Katariya
cb7c2ed848 Use trace_to_jaxpr_dynamic for the apply_primitive path; trace_to_jaxpr_final is only for final-style primitives. Also do some cleanup.
PiperOrigin-RevId: 586106427
2023-11-28 14:35:44 -08:00
Jake VanderPlas
ae662be5ef Fix typo in deprecation warning 2023-11-28 13:56:49 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
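
For illustration, plain usage of the public API; after the deprecation, callers should pass arrays explicitly:

```
import jax.numpy as jnp

jnp.array_equal(jnp.array([1, 2]), jnp.array([1, 2]))  # OK: array inputs
# jnp.array_equal([1, 2], [1, 2])  # deprecated: implicit conversion of list inputs
```
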
Jake VanderPlas
a8723ecb9c Fix grad of jnp.i0 at zero 2023-11-28 12:34:56 -08:00
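
A quick check of the fix: d/dx I0(x) = I1(x) and I1(0) = 0, so the gradient at zero should be 0.0 rather than nan:

```
import jax
import jax.numpy as jnp

print(jax.grad(jnp.i0)(0.0))  # 0.0 after the fix, was nan
```
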
Yash Katariya
f8e22ae512 Remove jaxpr_debug_info from MeshComputation since that information is available via AllArgsInfo
PiperOrigin-RevId: 586018345
2023-11-28 10:05:15 -08:00
George Necula
c6afdfd8d6 [shape_poly] Simplify the API for processing polymorphic_shape specifications
Before, we had `export.poly_spec` to create a `jax.ShapeDtypeStruct`
given a polymorphic shape specification. This function was invoked as
`poly_spec(arg_shape, arg_dtype, polymorphic_shape)`. The `arg_shape`
was only needed when the polymorphic shape spec contained placeholders.

We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:

`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.

We also rename `export.poly_specs` to `export.arg_specs`.
2023-11-28 12:45:59 +02:00
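
A sketch of the new usage following the pattern above; the import path reflects jax.experimental.export as of this change, and the "_" placeholder syntax is an assumption based on the shape-polymorphism docs:

```
import numpy as np
import jax
from jax.experimental.export import export

arg = np.zeros((8, 28), np.float32)
# Parse the spec; `like` supplies concrete sizes for any "_" placeholders.
shape = export.symbolic_shape("b, _", like=arg.shape)
spec = jax.ShapeDtypeStruct(shape, arg.dtype)
```
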
Yash Katariya
88d980f164 Typecheck avals and sharding for arguments that were DCE'd.
This keeps AOT's promise that no recompilation will occur.

Fixes https://github.com/google/jax/issues/18686

PiperOrigin-RevId: 585855658
2023-11-27 22:39:32 -08:00
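
A sketch of the behavior being guaranteed (function and shapes are illustrative):

```
import jax
import jax.numpy as jnp

def f(x, y):
    return x * 2.0          # y is unused, so it gets DCE'd away

compiled = jax.jit(f).lower(jnp.ones(4), jnp.ones(8)).compile()
compiled(jnp.ones(4), jnp.ones(8))      # OK
# After this change, a mismatched aval for the DCE'd argument is rejected
# instead of silently accepted:
# compiled(jnp.ones(4), jnp.ones(16))   # raises a typecheck error
```
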
Yash Katariya
37f11428a3 Fix indexing bug when querying _input_layouts
PiperOrigin-RevId: 585842302
2023-11-27 21:12:07 -08:00
Yash Katariya
81aee237d8 Simplify the lower_sharding_computation signature by always taking a closed jaxpr as input. For apply_primitive, do the tracing to a jaxpr in dispatch.py.
PiperOrigin-RevId: 585810475
2023-11-27 18:00:57 -08:00
Yash Katariya
2ed0fc4d1c compiled.input_layouts() should preserve the structure of the original in_tree.
JAX by default DCEs arguments that are unused, which changes the in_layouts available on the `executable`. This breaks when we try to unflatten those in_layouts with the original in_tree (because the in_tree covers both DCE'd and non-DCE'd args).

The in_layouts that we return to the user should contain layouts for both DCE'd and non-DCE'd args, so fill the DCE'd entries with None, which means the default layout. This does not affect the actual HLO computation, because JAX discards the DCE'd layouts anyway, consequently discarding the jax.Arrays created with those layouts.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 585790912
2023-11-27 16:23:58 -08:00
jax authors
b9b5410ddd Default-enable the JAX persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
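
For reference, the cache directory remains configurable through the existing option; a minimal sketch:

```
import jax

# Override the new default cache location if desired.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
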
Jake VanderPlas
01fde43fce Fix sign of jax.scipy.special.gamma for negative inputs 2023-11-27 14:08:02 -08:00
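
A quick check: Γ(-0.5) = -2√π ≈ -3.5449, so the result should now come out negative:

```
from jax.scipy.special import gamma

print(gamma(-0.5))  # ≈ -3.5449 after the fix; the sign was previously dropped
```
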
Jake VanderPlas
8bc486a48f Fix debug_nans issue in sort 2023-11-27 08:41:19 -08:00
Tomás Longeri
523f36153f [Mosaic] Use C++ apply-vector-layout as default setting
PiperOrigin-RevId: 585663081
2023-11-27 08:39:42 -08:00
Adam Paszke
30a76756c3 [Mosaic] Accept that the dump flag might be missing.
The flag ends up in libtpu in the OSS build, so we will need to find a different
mechanism for this. This is a quick fix to get things working again.

PiperOrigin-RevId: 584835560
2023-11-23 02:21:29 -08:00
jax authors
bf2464ec82 Merge pull request #18652 from mattjj:debug-callback-docstring-2
PiperOrigin-RevId: 584731668
2023-11-22 15:22:30 -08:00
Matthew Johnson
997db225e2 small tweaks to jax.debug.print docstring 2023-11-22 15:03:41 -08:00
jax authors
4770690b4a Merge pull request #18647 from jakevdp:pref-eltype-doc
PiperOrigin-RevId: 584715788
2023-11-22 14:07:04 -08:00
Jake VanderPlas
530cf30bfc xla_bridge: add logic to avoid version skew 2023-11-22 12:17:21 -08:00
Jake VanderPlas
2acdb120a0 DOC: document preferred_element_type argument to dot functions 2023-11-22 09:49:34 -08:00
Tom Hennigan
1b504bb68e Allow threads to race setting attributes on Mesh.
PiperOrigin-RevId: 584602313
2023-11-22 05:47:56 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
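
The failing pattern, for reference:

```
import jax.numpy as jnp

try:
    jnp.zeros(5)[:, 0]     # two indices into a 1-D array
except IndexError as e:
    print(e)               # the message now spells out the rank mismatch
```
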
jax authors
2efa5862ad Merge pull request #18563 from jakevdp:cleanup-fold-in
PiperOrigin-RevId: 584414134
2023-11-21 13:37:51 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
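
An illustration of the constraint, matching numpy.where's signature:

```
import jax.numpy as jnp

c = jnp.array([True, False])
a, b = jnp.zeros(2), jnp.ones(2)

jnp.where(c, a, b)                  # OK: positional arguments
# jnp.where(condition=c, x=a, y=b)  # no longer accepted: positional-only
```
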
George Necula
a6284a03e5 Disable an integer_pow harness due to overflow.
PiperOrigin-RevId: 584261696
2023-11-21 02:23:22 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
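
A usage sketch of the op involved; lax.linalg.eig is implemented on CPU and, with default flags, returns eigenvalues plus left and right eigenvectors:

```
import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
# Non-finite entries in `a` now raise a clear error instead of hanging.
w, vl, vr = linalg.eig(a)
```
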
jax authors
fc8058a17d Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.

Testing: test workload.
PiperOrigin-RevId: 584148686
2023-11-20 15:52:03 -08:00
Krasimir Georgiev
9287a6369d Integrate LLVM at llvm/llvm-project@9bdbb8226e
Updates LLVM usage to match
[9bdbb8226e70](https://github.com/llvm/llvm-project/commit/9bdbb8226e70)

PiperOrigin-RevId: 584091615
2023-11-20 12:01:57 -08:00
Yash Katariya
d6a9352270 Only call executable.get_output_memory_kinds() if jax_enable_memories is True
PiperOrigin-RevId: 584087022
2023-11-20 11:44:30 -08:00
jax authors
ab9c973031 Merge pull request #18600 from nouiz:doc_compilation_cache
PiperOrigin-RevId: 584068904
2023-11-20 10:43:32 -08:00
jax authors
e388d48d1f Merge pull request #18228 from JiaYaobo:random.binomial
PiperOrigin-RevId: 584045591
2023-11-20 09:16:41 -08:00
jax authors
4fd93c3226 Merge pull request #18593 from lgeiger:xeinsum-error-msg
PiperOrigin-RevId: 584024405
2023-11-20 07:47:16 -08:00
Frederic Bastien
72b6c9cf0b Document the compilation cache 2023-11-20 07:03:30 -08:00
George Necula
2d9da6c8fb Clean up the code for picking lowering rules based on platform.
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform`, had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
jiayaobo
ae2387dc27 add random.binomial
2023-11-19 14:51:10 +08:00
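
A usage sketch of the new sampler; parameter names follow the merged API, and shape handling is assumed to match the other jax.random samplers:

```
import jax

key = jax.random.PRNGKey(0)
samples = jax.random.binomial(key, n=10, p=0.3, shape=(4,))  # four draws from Binomial(10, 0.3)
```
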
Yash Katariya
c8ef37507b Make the SpecifiedLayout class opaque.
Also need to enable pickling of xc.Layout so that AOT serialization continues to work.

PiperOrigin-RevId: 583684299
2023-11-18 15:17:16 -08:00