15932 Commits

Author SHA1 Message Date
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Yash Katariya
40349a8612 Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
jax authors
d5289e627f Merge pull request #15804 from froystig:issue13949
PiperOrigin-RevId: 528790988
2023-05-02 08:30:46 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
jax authors
12e3db5fbc Merge pull request #15813 from jakevdp:keyarray-device-put-sharded
PiperOrigin-RevId: 528578837
2023-05-01 14:41:55 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00
Skye Wanderman-Milne
70cac773f7 Exclude scipy_fft_test from msan as well as t/asan.
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882 Bump up lax_test TPU sharding to avoid asan timeouts
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
jax authors
3386e77cfa Merge pull request #15809 from pizzud:docfix
PiperOrigin-RevId: 528510503
2023-05-01 10:36:24 -07:00
jax authors
2b01186b84 Merge pull request #15808 from jakevdp:min-python-version
PiperOrigin-RevId: 528503926
2023-05-01 10:15:46 -07:00
David Pizzuto
6948d32d15 contributing: Switch repo URL to HTTPS for consistency with other github URLs. 2023-05-01 10:03:39 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Jake VanderPlas
8a4eadfec5 Change Python 3.8 CI runs to Python 3.9
Followup for #15805

PiperOrigin-RevId: 528495690
2023-05-01 09:47:51 -07:00
Yash Katariya
e51d12cdef Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
ba11b9dcba Remove tupling of custom call results.
MHLO-to-HLO conversion now knows how to introduce tuples to custom calls if needed, so we can remove explicit tupling from JAX.

PiperOrigin-RevId: 528485268
2023-05-01 09:02:14 -07:00
Roy Frostig
8d4d520933 resolve opaque dtypes in MLIR callback lowering and in XLA shape translation 2023-05-01 08:21:54 -07:00
George Necula
b981b4f68f [shape_poly] Re-enable tests that had been fixed
These tests had been fixed through changes in StableHLO
and upgrading the JAX lowering paths to not use
fallback HLO lowering.

PiperOrigin-RevId: 528465080
2023-05-01 07:15:17 -07:00
jax authors
40d061a38b Merge pull request #15800 from gnecula:poly_opaque
PiperOrigin-RevId: 528451726
2023-05-01 05:43:02 -07:00
George Necula
1876d9691f [shape_poly] Fix vmap(while) case for opaque types 2023-05-01 12:45:19 +02:00
Christina Sorokin
63d87c6c3d Add new attribute function_list to XLACallModule and bump the version.
PiperOrigin-RevId: 528076798
2023-04-28 22:34:41 -07:00
George Necula
5b7e8d0765 [jax2tf] Simplify back_compat_test.py to use jax_export mechanisms to run
the serialized module, instead of relying on tf.XlaCallModule.

PiperOrigin-RevId: 528061968
2023-04-28 21:22:35 -07:00
Skye Wanderman-Milne
c662fd216d Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.

PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
jax authors
b8f3caf4b7 Merge pull request #15790 from jakevdp:sparse-doc
PiperOrigin-RevId: 527998499
2023-04-28 15:18:08 -07:00
Jake VanderPlas
e059e3b52f DOC: document jax.experimental.sparse.linalg 2023-04-28 14:18:50 -07:00
jax authors
0814b874d5 Merge pull request #15779 from jakevdp:keyty-itemsize
PiperOrigin-RevId: 527914223
2023-04-28 10:00:14 -07:00
jax authors
566f17513b Merge pull request #15770 from gnecula:clean_call_tf
PiperOrigin-RevId: 527841341
2023-04-28 03:54:21 -07:00
George Necula
161664e858 [call_tf] Some cleanup of call_tf
The main cleanup is around _code_generator_and_avals, which in
an earlier version of the code was used for both abstract values
and for code generation. That is why it was cached, and why it
returned a code generator and abstract values. A while
ago we did a first round of cleaning to not use it for abstract
values. Now we can actually eliminate the function and inline
it directly.

A second improvement is to add the explicit error message from
TF commpilation, instead of just the generic message that
call_tf cannot be used with non-compileable functions.
2023-04-28 12:38:27 +02:00
jax authors
818805f6f9 Improve error message in _create_device_mesh_for_nd_torus
PiperOrigin-RevId: 527834640
2023-04-28 03:15:06 -07:00
John QiangZhang
5b4388ad03 Add new attribute function_list to XLACallModule and bump the version.
PiperOrigin-RevId: 527741961
2023-04-27 18:28:12 -07:00
Jake VanderPlas
054fca5cd4 KeyArray: define itemsize on opaque dtype 2023-04-27 15:59:57 -07:00
Peter Hawkins
84c516974a Revert: Switch to using Clang as the default compiler.
It appears this is causing deadlocks in multi-gpu tests.

PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Jake VanderPlas
50405b1081 KeyArray: add size attribute 2023-04-27 14:06:55 -07:00
Yash Katariya
86c1f5bcee Preserve the sharding type of physical sharding on logical sharding when .sharding is accessed on a PRNGKeyArray
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Dinghua Li
7d6fb535a9 [shape_poly] Add support for shape polymorphism for _unsafe_rbg_split.
PiperOrigin-RevId: 527619524
2023-04-27 10:36:36 -07:00
jax authors
5027543c11 Merge pull request #15753 from gnecula:poly_unif
PiperOrigin-RevId: 527590089
2023-04-27 08:52:38 -07:00
George Necula
876c53abb7 [shape_poly] Refactor the unification of the argument abstract values with the actual arguments
This was called shape_poly.compute_dim_values. We rename it to
shape_poly.unify_avals_with_args and we add better error reporting to it.
Now it will identify the arg/kwarg where there is a shape discrepancy.

This is intended to be a pure refactoring, in preparation for adding
support for shape polymorphism to jax_export.call_exported.
2023-04-27 08:59:59 +02:00
jax authors
c6457f3153 Fix bug in the JAX spectral bisection eigensolver implementing the QDWH-eigh algorithm.
Use the largest columns, as intended, of the projector P for the initial guess of the subspace iteration, instead of the smallest.

PiperOrigin-RevId: 527418006
2023-04-26 17:19:05 -07:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Skye Wanderman-Milne
67d80c21cb Increase sharding count on nn_test and svd_test to avoid ASAN timeouts.
PiperOrigin-RevId: 527387005
2023-04-26 15:11:29 -07:00
jax authors
139e1c2f92 Merge pull request #15762 from jakevdp:keyarray-eq
PiperOrigin-RevId: 527366188
2023-04-26 13:54:32 -07:00
Jake VanderPlas
e46d7f673b KeyArray: use assertArraysEqual in place of assertKeysEqual 2023-04-26 13:15:03 -07:00
Jake VanderPlas
fcffbac346 KeyArray: implement __eq__ and __ne__ 2023-04-26 13:12:24 -07:00
jax authors
be74b07800 Merge pull request #15757 from jakevdp:keyarray-operators
PiperOrigin-RevId: 527353108
2023-04-26 13:09:59 -07:00
Jake VanderPlas
a47a71ff80 KeyArray: better errors for operators 2023-04-26 11:34:07 -07:00
jax authors
a6b052ae78 Merge pull request #15748 from jakevdp:keyarray-at
PiperOrigin-RevId: 527322597
2023-04-26 11:29:25 -07:00
Parker Schuh
5f4408ded7 Convert inspect_sharding to register the handler directly in c++ so that it can
work across the c-api boundary.

PiperOrigin-RevId: 527322386
2023-04-26 11:22:28 -07:00
jax authors
1551d9e816 Merge pull request #15755 from jakevdp:sharding-constraint-doc
PiperOrigin-RevId: 527317106
2023-04-26 11:04:53 -07:00
John QiangZhang
af051ba49c [2/n] Embed the tf.Graph into the stablehlo.custom_call.
PiperOrigin-RevId: 527302563
2023-04-26 10:20:46 -07:00
Jake VanderPlas
8dc06ed2ce Document jax.lax.with_sharding_constraint 2023-04-26 10:19:04 -07:00