11698 Commits

Author SHA1 Message Date
Peter Hawkins
458a8962be Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
2023-11-29 05:37:58 -08:00
jax authors
528a90e97d Merge pull request #18724 from apaszke:downgrade-logging
PiperOrigin-RevId: 586304793
2023-11-29 05:06:18 -08:00
Peter Hawkins
1e961b80da Remove fallback path that lowers all_gather via psum.
As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would.

I'm planning to improve the XLA:CPU lowering in a subsequent change.

PiperOrigin-RevId: 586291911
2023-11-29 04:14:11 -08:00
Adam Paszke
d80c15aaee Downgrade a bunch of logging to DEBUG
The logs related to compilation cache ended up being quite chatty,
which is quite unlike the other logs in JAX. This downgrades a bunch
of them to debug, as they can always be enabled independently
using JAX config. This should also fix the recent failures in
logging_test.py.
2023-11-29 12:10:53 +00:00
Sharad Vikram
8dfbf90602 [Pallas/Mosaic] Add support for barrier semaphores
PiperOrigin-RevId: 586289340
2023-11-29 04:04:11 -08:00
jax authors
7d4d912065 Merge pull request #18632 from mattjj:shmap-eager-custom-jvp-vjp
PiperOrigin-RevId: 586143644
2023-11-28 16:56:43 -08:00
jax authors
28063b886d Merge pull request #18718 from jakevdp:sparse-conj
PiperOrigin-RevId: 586131627
2023-11-28 16:14:20 -08:00
Matthew Johnson
7589c2bdb8 [shard_map] implement eager custom_jvp / custom_vjp 2023-11-28 16:08:56 -08:00
jax authors
2ccdfa6da2 Merge pull request #18711 from mattjj:shmap-transpose-fix-3
PiperOrigin-RevId: 586131306
2023-11-28 16:06:08 -08:00
Jake VanderPlas
05d18ac998 [sparse] add sparsify support for conj_p 2023-11-28 15:46:47 -08:00
Matthew Johnson
5fbda6d060 [shard-map] fix transpose replication checking bug with integer_pow 2023-11-28 15:39:32 -08:00
jax authors
5178afc81e Merge pull request #18714 from jakevdp:sparse-shape
PiperOrigin-RevId: 586123122
2023-11-28 15:34:16 -08:00
Jake VanderPlas
d6d7061d71 [sparse] canonicalize sparse shapes 2023-11-28 14:38:48 -08:00
Yash Katariya
cb7c2ed848 Use trace_to_jaxpr_dynamic for the apply_primitive path. trace_to_jaxpr_final is only for final style primitives. Also do some cleanup.
PiperOrigin-RevId: 586106427
2023-11-28 14:35:44 -08:00
Jake VanderPlas
ae662be5ef Fix typo in deprecation warning 2023-11-28 13:56:49 -08:00
Jake VanderPlas
a8723ecb9c Fix grad of jnp.i0 at zero 2023-11-28 12:34:56 -08:00
Yash Katariya
f8e22ae512 Remove jaxpr_debug_info from MeshComputation since that information is available via AllArgsInfo
PiperOrigin-RevId: 586018345
2023-11-28 10:05:15 -08:00
George Necula
66f9078d46 Remove redundant jax2tf/tests/shape_poly_test tests.
The removed tests run identically in tests/shape_poly_test.py

PiperOrigin-RevId: 585916166
2023-11-28 03:39:15 -08:00
George Necula
c6afdfd8d6 [shape_poly] Simplify the API for processing polymorphic_shape specifications
Before, we had `export.poly_spec` to create a jax.ShapedDtypeStruct`
given a polymorphic shape specification. This function was
invoked `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.

We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:

`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.

We also rename the `export.poly_specs` to `export.arg_specs`.
2023-11-28 12:45:59 +02:00
Yash Katariya
88d980f164 Typecheck avals and sharding for arguments that were DCE'd.
This keeps the promise of AOT that recompilation is guaranteed.

Fixes https://github.com/google/jax/issues/18686

PiperOrigin-RevId: 585855658
2023-11-27 22:39:32 -08:00
Yash Katariya
37f11428a3 Fix indexing bug when querying _input_layouts
PiperOrigin-RevId: 585842302
2023-11-27 21:12:07 -08:00
Yash Katariya
81aee237d8 Simply lower_sharding_computation signature by always taking a closed jaxpr as input. For apply_primitive do the tracing to jaxpr in dispatch.py
PiperOrigin-RevId: 585810475
2023-11-27 18:00:57 -08:00
Yash Katariya
2ed0fc4d1c compiled.input_layouts() should preserve the structure of the original in_tree.
JAX by default DCE's arguments that are unused which changes the in_layouts available on the `executable`. This breaks when we try to unflatten the said in_layouts with the original in_tree (because in_tree has all the args DCE'd + non-DCE'd).

The in_layouts that we return to the user should contain layouts for DCE'd + non-DCE'd args. So fill the DCE'd layouts with None which means the default layout. This does not affect the actual HLO computation because JAX will discard the DCE'd layouts anyways, consequently discarding the jax.Arrays created with those layouts.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 585790912
2023-11-27 16:23:58 -08:00
jax authors
b9b5410ddd Default-enable the Jax persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
Jake VanderPlas
01fde43fce Fix sign of jax.scipy.special.gamma for negative inputs 2023-11-27 14:08:02 -08:00
Jake VanderPlas
8bc486a48f Fix debug_nans issue in sort 2023-11-27 08:41:19 -08:00
Tomás Longeri
523f36153f [Mosaic] Use C++ apply-vector-layout as default setting
PiperOrigin-RevId: 585663081
2023-11-27 08:39:42 -08:00
Adam Paszke
30a76756c3 [Mosaic] Accept that the dump flag might be missing.
The flag ends up in libtpu in the OSS build, so we will need to find a different
mechanism for this. This is a quick fix to get things working again.

PiperOrigin-RevId: 584835560
2023-11-23 02:21:29 -08:00
jax authors
bf2464ec82 Merge pull request #18652 from mattjj:debug-callback-docstring-2
PiperOrigin-RevId: 584731668
2023-11-22 15:22:30 -08:00
Matthew Johnson
997db225e2 small tweaks to jax.debug.print docstring 2023-11-22 15:03:41 -08:00
jax authors
4770690b4a Merge pull request #18647 from jakevdp:pref-eltype-doc
PiperOrigin-RevId: 584715788
2023-11-22 14:07:04 -08:00
jax authors
d63f5190f5 Merge pull request #18646 from jakevdp:version-commit
PiperOrigin-RevId: 584701095
2023-11-22 13:00:42 -08:00
Jake VanderPlas
530cf30bfc xla_bridge: add logic to avoid version skew 2023-11-22 12:17:21 -08:00
Jake VanderPlas
2acdb120a0 DOC: document preferred_element_type argument to dot functions 2023-11-22 09:49:34 -08:00
Jake VanderPlas
cc9656d654 version: fix commit hash extraction 2023-11-22 09:35:03 -08:00
Tom Hennigan
1b504bb68e Allow threads to race setting attributes on Mesh.
PiperOrigin-RevId: 584602313
2023-11-22 05:47:56 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
jax authors
b48254e084 Merge pull request #18599 from mattjj:shmap-eager-axis-index
PiperOrigin-RevId: 584423708
2023-11-21 14:05:23 -08:00
jax authors
2efa5862ad Merge pull request #18563 from jakevdp:cleanup-fold-in
PiperOrigin-RevId: 584414134
2023-11-21 13:37:51 -08:00
Utku Evci
83b6c3f450 Removing unused description
PiperOrigin-RevId: 584395995
2023-11-21 12:17:20 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
George Necula
a6284a03e5 Disable a integer_pow harness due to overflow.
PiperOrigin-RevId: 584261696
2023-11-21 02:23:22 -08:00
George Necula
139c0504cc [jax2tf] Disable test broken by TensorFlow
Also check all the jax2tf tests to ensure that each one
has at least one TPU configuration marked for TAP continuous.
Without this we will only notice failures on TPU post submit, as it was
the case here.

PiperOrigin-RevId: 584253387
2023-11-21 01:42:37 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
fc8058a17d Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.

Testing: test workload.
PiperOrigin-RevId: 584148686
2023-11-20 15:52:03 -08:00
Krasimir Georgiev
9287a6369d Integrate LLVM at llvm/llvm-project@9bdbb8226e
Updates LLVM usage to match
[9bdbb8226e70](https://github.com/llvm/llvm-project/commit/9bdbb8226e70)

PiperOrigin-RevId: 584091615
2023-11-20 12:01:57 -08:00
Yash Katariya
d6a9352270 Only call executale.get_output_memory_kinds() if jax_enable_memories is True
PiperOrigin-RevId: 584087022
2023-11-20 11:44:30 -08:00
jax authors
ab9c973031 Merge pull request #18600 from nouiz:doc_compilation_cache
PiperOrigin-RevId: 584068904
2023-11-20 10:43:32 -08:00
jax authors
e388d48d1f Merge pull request #18228 from JiaYaobo:random.binomial
PiperOrigin-RevId: 584045591
2023-11-20 09:16:41 -08:00
jax authors
4fd93c3226 Merge pull request #18593 from lgeiger:xeinsum-error-msg
PiperOrigin-RevId: 584024405
2023-11-20 07:47:16 -08:00