15977 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
f28b20175f Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release 2023-05-04 14:38:46 -07:00
Matthew Johnson
2845df03fc In jax.remat/jax.checkpoint, don't cache on Tracers in static args
Why do we have caching in jax.remat at all? I added it in
https://github.com/google/jax/pull/11743 without much justification other than
it made some tests faster. I think I was worried that the switch to the new
remat's "initial-style" (jaxpr forming up-front) approach would regress
eager-mode performance, so I added benchmarks to measure it and then made those
fast with caching.

But the caching seems a bit too aggressive when static_argnums are involved. In
particular, I allowed caching on Tracer arguments (by object id). That seems
dangerous!

So the change here is to check whether any of the arguments marked static by
static_argnums are Tracers. If so, skip the caching. This change happens not to
affect the benchmarks at all.

PiperOrigin-RevId: 529502687
2023-05-04 13:42:00 -07:00
jax authors
e6e6490ab0 Merge pull request #15247 from jakevdp:ml-dtypes-finfo
PiperOrigin-RevId: 529463737
2023-05-04 11:21:04 -07:00
Jake VanderPlas
59e6ed213e Use ml_dtypes definition for jnp.finfo 2023-05-04 10:40:44 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
jax authors
68614b4dcc [XLA:TPU] Fix a bug in eigh that caused a slight loss of accuracy.
PiperOrigin-RevId: 529406623
2023-05-04 07:49:04 -07:00
Peter Hawkins
09fce87f54 Increase sharding of or disable some flaky CI tests.
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
George Necula
40aa4e1781 [shape_poly] Disable tests for eigh shape polymorphism.
We are seeing some failures when comparing the results
for eigh with shape polymorphism and without.
Normally, shape polymorphism should not change the HLO
so a golden comparison is not necessarily bad, even though
for eigh we should check for correctness of the results
rather than identity.

We need to investigate this further but meanwhile turn
off these tests. The changes introduced recently for
shape polymorphism for eigh are not affecting the
code paths in absence of shape polymorphism. So it
is appropriate to just turn off the tests, and add
an error that shape polymorphism for eigh on
GPU is not ready.

PiperOrigin-RevId: 529388749
2023-05-04 06:14:18 -07:00
Adam Paszke
9c5e3f7ecc Verify that slices are trivial before discarding them in state primitives
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...

PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
jax authors
ebcad11862 Merge pull request #15842 from gnecula:tf_eval_shape
PiperOrigin-RevId: 529292794
2023-05-03 22:09:00 -07:00
George Necula
f66d15c831 [shape_poly] Add a version of jax.eval_shape that works with shape polymorphism
The use would be to find the output shapes for a function in
presence of shape polymorphism, and to compute the
`polymorphic_shapes` value that can be used in a subsequent
call to `jax2tf.convert`.
2023-05-04 06:54:13 +02:00
John QiangZhang
8acbe1557c Update stablehlo.custom_call call_target name based on design doc discussion.
PiperOrigin-RevId: 529281826
2023-05-03 21:23:13 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
Yash Katariya
47fc23d7ba Make rnn_bwd_abstract_eval backwards compatible by guarding it agains the jaxlib version
PiperOrigin-RevId: 529260653
2023-05-03 19:28:42 -07:00
jax authors
c15f30f22e Instrument a new metric to measure the savings of async checkpoint in JAX.
Create the new metric '/jax/checkpoint/write/async/thread_duration_sec' to measure the savings from the async thread creation time.

PiperOrigin-RevId: 529227213
2023-05-03 16:36:54 -07:00
jax authors
a7aaa5a6be Merge pull request #15852 from jakevdp:warning-suppression
PiperOrigin-RevId: 529192921
2023-05-03 15:36:31 -07:00
Yash Katariya
260a4305ac Guard the rnn changes on the jaxlib version to be backwards compatible
PiperOrigin-RevId: 529184172
2023-05-03 13:42:49 -07:00
Jake VanderPlas
13f7291ff6 Remove obsolete warning suppression from pyproject.toml 2023-05-03 13:16:41 -07:00
jax authors
d84d19b7d1 Merge pull request #15846 from jakevdp:deprecate-make-sharded
PiperOrigin-RevId: 529172585
2023-05-03 13:02:33 -07:00
jax authors
95e1e6d3ef Merge pull request #15849 from hawkinsp:xla
PiperOrigin-RevId: 529156176
2023-05-03 12:02:40 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Peter Hawkins
51c40f04b9 Bump XLA version. 2023-05-03 18:52:04 +00:00
jax authors
b512477820 Merge pull request #15825 from jakevdp:jax2tf-prng
PiperOrigin-RevId: 529124128
2023-05-03 10:11:58 -07:00
Jake VanderPlas
9cfe77d5e1 Remove use of deprecated make_sharded_device_array 2023-05-03 10:11:29 -07:00
jax authors
47f5c225a9 Merge pull request #15789 from nouiz:ci
PiperOrigin-RevId: 529121113
2023-05-03 10:02:01 -07:00
John QiangZhang
7fe62b5406 Bump XLACallModule to version 5 and add the function_list.
PiperOrigin-RevId: 529106145
2023-05-03 09:05:08 -07:00
jax authors
5d143e6eea Merge pull request #15818 from froystig:random-bits-direct
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Rahul Joshi
9d750ae97d Fix pjit outfeed test avoid potential deadlocks.
PiperOrigin-RevId: 529076350
2023-05-03 06:51:26 -07:00
Benjamin Kramer
545c483e50 Re-enable testTruncNormPdf on CPU
Breaking change was reverted in LLVM 3b8bc83527

PiperOrigin-RevId: 529072697
2023-05-03 06:31:59 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
George Necula
a2ac510dc3 [shape_poly] Add support for dynamic shapes for eigh
We can only handle dynamic sizes for the batch dimensions for now.

PiperOrigin-RevId: 529001830
2023-05-02 23:27:59 -07:00
George Necula
6dfd248e74 [shape_poly] Add support for shape polymorphism for prng GPU custom call
We are using the new support for dynamic shapes for hlo.CustomCallOp, where
we need to pass the output shapes as additional operands.

This allows us to enable multiple "random" tests that were previously disabled.

PiperOrigin-RevId: 528990469
2023-05-02 22:26:58 -07:00
James Bradbury
8afec934fe [shard_map] Avoid nondeterminism in shmap transpose psum axes
PiperOrigin-RevId: 528969592
2023-05-02 21:08:14 -07:00
jax authors
36975245d4 Modify the QDWH algorithm to run subspace iteration on the projector of smallest rank.
The QDWH splitting step involves two orthogonal projectors
P_plus = -0.5*(U-I) and P_minus = 0.5*(U+I), one of which will have rank k and the other rank n-k. Ideally, if we are able to pick the median eigenvalue for the split point optimally, k will be near n/2, and the rank of the two projectors will be similar. However, if our guess of the median eigenvalue is poor or the matrix is rank-deficient, k can be far from n/2, and the cost of the subspace iteration will be higher for the projector of higher rank, since it involves computing the QR decomposition of a matrix of size n x rank.

This change makes the algorithm adaptively pick the projector of lower rank.

PiperOrigin-RevId: 528941151
2023-05-02 18:14:33 -07:00
Matthew Johnson
56feaca7f9 update cuDNN RNN code not to save 'workspace' scratch between fwd and bwd
PiperOrigin-RevId: 528928263
2023-05-02 17:05:42 -07:00
Yash Katariya
b698390171 Handle multihost pmap in pmap shard_map merge. This involves lifting the host local inputs to global inputs and vice-versa on the outputs.
To handle Tracers, ShapedArray, concrete Arrays, etc `global_array_to_host_local_array` and `host_local_array_to_global_array` are now primitives.

PiperOrigin-RevId: 528925663
2023-05-02 16:53:22 -07:00
Yash Katariya
7530ac1e09 Improve the error message for incompatible avals when the aval is a scalar
PiperOrigin-RevId: 528918215
2023-05-02 16:22:30 -07:00
Yash Katariya
356cac014c [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528907173
2023-05-02 15:40:27 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
jax authors
519a96305b Merge pull request #15827 from jakevdp:update-lint-typecheck
PiperOrigin-RevId: 528898652
2023-05-02 15:04:02 -07:00
jax authors
162f09fc8d Stop recursion in spectral bisection eigensolver when the remaining sub-matrix has norm less than epsilon times the input matrix norm, which means that it is pure numerical noise.
PiperOrigin-RevId: 528891206
2023-05-02 14:35:07 -07:00
Peter Hawkins
57e62ca03c Reenable scipy_stats_test in CI.
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.

PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Jake VanderPlas
2fa2f82274 CI: run lint_and_typecheck under newest Python version 2023-05-02 12:09:53 -07:00
Jake VanderPlas
2956ca6e38 custom prng: fix jax2tf random_split test 2023-05-02 10:57:58 -07:00
Yash Katariya
40349a8612 Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
Frederic Bastien
8f1711a431 Install as all the other. 2023-05-02 08:36:08 -07:00
jax authors
d5289e627f Merge pull request #15804 from froystig:issue13949
PiperOrigin-RevId: 528790988
2023-05-02 08:30:46 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object is the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
jax authors
12e3db5fbc Merge pull request #15813 from jakevdp:keyarray-device-put-sharded
PiperOrigin-RevId: 528578837
2023-05-01 14:41:55 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00