Peter Hawkins
57e62ca03c
Reenable scipy_stats_test in CI.
...
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.
PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Yash Katariya
40349a8612
Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
...
Fixes https://github.com/google/jax/issues/15782
PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
jax authors
d5289e627f
Merge pull request #15804 from froystig:issue13949
...
PiperOrigin-RevId: 528790988
2023-05-02 08:30:46 -07:00
Yash Katariya
c52e48b6c0
Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
...
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.
PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
jax authors
12e3db5fbc
Merge pull request #15813 from jakevdp:keyarray-device-put-sharded
...
PiperOrigin-RevId: 528578837
2023-05-01 14:41:55 -07:00
Jake VanderPlas
979aa3235b
KeyArray: implement sharded & replicated device_put
2023-05-01 14:17:01 -07:00
Skye Wanderman-Milne
70cac773f7
Exclude scipy_fft_test from msan as well as t/asan.
...
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882
Bump up lax_test TPU sharding to avoid asan timeouts
...
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Yash Katariya
4a3fb238f6
Return the same sharding object if the output OpSharding matches the input OpSharding.
...
Fixes https://github.com/google/jax/issues/15782
PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
jax authors
3386e77cfa
Merge pull request #15809 from pizzud:docfix
...
PiperOrigin-RevId: 528510503
2023-05-01 10:36:24 -07:00
jax authors
2b01186b84
Merge pull request #15808 from jakevdp:min-python-version
...
PiperOrigin-RevId: 528503926
2023-05-01 10:15:46 -07:00
David Pizzuto
6948d32d15
contributing: Switch repo URL to HTTPS for consistency with other github URLs.
2023-05-01 10:03:39 -07:00
Jake VanderPlas
57af5360a1
Update required Python version to 3.9
2023-05-01 10:00:57 -07:00
Jake VanderPlas
8a4eadfec5
Change Python 3.8 CI runs to Python 3.9
...
Followup for #15805
PiperOrigin-RevId: 528495690
2023-05-01 09:47:51 -07:00
Yash Katariya
e51d12cdef
Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
...
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
ba11b9dcba
Remove tupling of custom call results.
...
MHLO-to-HLO conversion now knows how to introduce tuples to custom calls if needed, so we can remove explicit tupling from JAX.
PiperOrigin-RevId: 528485268
2023-05-01 09:02:14 -07:00
Roy Frostig
8d4d520933
resolve opaque dtypes in MLIR callback lowering and in XLA shape translation
2023-05-01 08:21:54 -07:00
George Necula
b981b4f68f
[shape_poly] Re-enable tests that had been fixed
...
These tests had been fixed through changes in StableHLO
and upgrading the JAX lowering paths to not use
fallback HLO lowering.
PiperOrigin-RevId: 528465080
2023-05-01 07:15:17 -07:00
jax authors
40d061a38b
Merge pull request #15800 from gnecula:poly_opaque
...
PiperOrigin-RevId: 528451726
2023-05-01 05:43:02 -07:00
George Necula
1876d9691f
[shape_poly] Fix vmap(while) case for opaque types
2023-05-01 12:45:19 +02:00
Christina Sorokin
63d87c6c3d
Add new attribute function_list
to XLACallModule and bump the version.
...
PiperOrigin-RevId: 528076798
2023-04-28 22:34:41 -07:00
George Necula
5b7e8d0765
[jax2tf] Simplify back_compat_test.py to use jax_export mechanisms to run
...
the serialized module, instead of relying on tf.XlaCallModule.
PiperOrigin-RevId: 528061968
2023-04-28 21:22:35 -07:00
Skye Wanderman-Milne
c662fd216d
Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
...
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.
PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
jax authors
b8f3caf4b7
Merge pull request #15790 from jakevdp:sparse-doc
...
PiperOrigin-RevId: 527998499
2023-04-28 15:18:08 -07:00
Jake VanderPlas
e059e3b52f
DOC: document jax.experimental.sparse.linalg
2023-04-28 14:18:50 -07:00
jax authors
0814b874d5
Merge pull request #15779 from jakevdp:keyty-itemsize
...
PiperOrigin-RevId: 527914223
2023-04-28 10:00:14 -07:00
jax authors
566f17513b
Merge pull request #15770 from gnecula:clean_call_tf
...
PiperOrigin-RevId: 527841341
2023-04-28 03:54:21 -07:00
George Necula
161664e858
[call_tf] Some cleanup of call_tf
...
The main cleanup is around _code_generator_and_avals, which in
an earlier version of the code was used for both abstract values
and for code generation. That is why it was cached, and why it
returned a code generator and abstract values. A while
ago we did a first round of cleaning to not use it for abstract
values. Now we can actually eliminate the function and inline
it directly.
A second improvement is to add the explicit error message from
TF commpilation, instead of just the generic message that
call_tf cannot be used with non-compileable functions.
2023-04-28 12:38:27 +02:00
jax authors
818805f6f9
Improve error message in _create_device_mesh_for_nd_torus
...
PiperOrigin-RevId: 527834640
2023-04-28 03:15:06 -07:00
John QiangZhang
5b4388ad03
Add new attribute function_list
to XLACallModule and bump the version.
...
PiperOrigin-RevId: 527741961
2023-04-27 18:28:12 -07:00
Jake VanderPlas
054fca5cd4
KeyArray: define itemsize on opaque dtype
2023-04-27 15:59:57 -07:00
Peter Hawkins
84c516974a
Revert: Switch to using Clang as the default compiler.
...
It appears this is causing deadlocks in multi-gpu tests.
PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Jake VanderPlas
50405b1081
KeyArray: add size attribute
2023-04-27 14:06:55 -07:00
Yash Katariya
86c1f5bcee
Preserve the sharding type of physical sharding on logical sharding when .sharding
is accessed on a PRNGKeyArray
...
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Dinghua Li
7d6fb535a9
[shape_poly] Add support for shape polymorphism for _unsafe_rbg_split.
...
PiperOrigin-RevId: 527619524
2023-04-27 10:36:36 -07:00
jax authors
5027543c11
Merge pull request #15753 from gnecula:poly_unif
...
PiperOrigin-RevId: 527590089
2023-04-27 08:52:38 -07:00
George Necula
876c53abb7
[shape_poly] Refactor the unification of the argument abstract values with the actual arguments
...
This was called shape_poly.compute_dim_values. We rename it to
shape_poly.unify_avals_with_args and we add better error reporting to it.
Now it will identify the arg/kwarg where there is a shape discrepancy.
This is intended to be a pure refactoring, in preparation for adding
support for shape polymorphism to jax_export.call_exported.
2023-04-27 08:59:59 +02:00
jax authors
c6457f3153
Fix bug in the JAX spectral bisection eigensolver implementing the QDWH-eigh algorithm.
...
Use the largest columns, as intended, of the projector P for the initial guess of the subspace iteration, instead of the smallest.
PiperOrigin-RevId: 527418006
2023-04-26 17:19:05 -07:00
Yash Katariya
34d5a6259f
Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
...
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.
PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Skye Wanderman-Milne
67d80c21cb
Increase sharding count on nn_test and svd_test to avoid ASAN timeouts.
...
PiperOrigin-RevId: 527387005
2023-04-26 15:11:29 -07:00
jax authors
139e1c2f92
Merge pull request #15762 from jakevdp:keyarray-eq
...
PiperOrigin-RevId: 527366188
2023-04-26 13:54:32 -07:00
Jake VanderPlas
e46d7f673b
KeyArray: use assertArraysEqual in place of assertKeysEqual
2023-04-26 13:15:03 -07:00
Jake VanderPlas
fcffbac346
KeyArray: implement __eq__ and __ne__
2023-04-26 13:12:24 -07:00
jax authors
be74b07800
Merge pull request #15757 from jakevdp:keyarray-operators
...
PiperOrigin-RevId: 527353108
2023-04-26 13:09:59 -07:00
Jake VanderPlas
a47a71ff80
KeyArray: better errors for operators
2023-04-26 11:34:07 -07:00
jax authors
a6b052ae78
Merge pull request #15748 from jakevdp:keyarray-at
...
PiperOrigin-RevId: 527322597
2023-04-26 11:29:25 -07:00
Parker Schuh
5f4408ded7
Convert inspect_sharding to register the handler directly in c++ so that it can
...
work across the c-api boundary.
PiperOrigin-RevId: 527322386
2023-04-26 11:22:28 -07:00
jax authors
1551d9e816
Merge pull request #15755 from jakevdp:sharding-constraint-doc
...
PiperOrigin-RevId: 527317106
2023-04-26 11:04:53 -07:00
John QiangZhang
af051ba49c
[2/n] Embed the tf.Graph into the stablehlo.custom_call.
...
PiperOrigin-RevId: 527302563
2023-04-26 10:20:46 -07:00
Jake VanderPlas
8dc06ed2ce
Document jax.lax.with_sharding_constraint
2023-04-26 10:19:04 -07:00