16016 Commits

Author SHA1 Message Date
jax authors
1b9180167b Merge pull request #15945 from skye:version
PiperOrigin-RevId: 530722158
jaxlib-v0.4.9 jax-v0.4.9 jax-v0.4.9-rc
2023-05-09 14:59:20 -07:00
Skye Wanderman-Milne
5bcd9dcc46 Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2 2023-05-09 14:49:54 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Previously, if a SingleDeviceSharding went through `to_gspmd_sharding` and then an equal SingleDeviceSharding (created when device/backend is set) went through `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding, which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
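The cache problem described in befa29b566 can be illustrated with a small, self-contained sketch. It is not the JAX implementation: `ToySharding`, `convert_buggy`, and `convert_fixed` are hypothetical names, and the only point is that an attribute ignored by `__eq__`/`__hash__` has to be passed as an explicit argument if the cache is meant to distinguish on it.

```python
from functools import lru_cache


class ToySharding:
    """Hypothetical stand-in for a sharding object; not a JAX class."""

    def __init__(self, device, device_backend_set=False):
        self.device = device
        # "Dynamic" attribute: deliberately ignored by __eq__ and __hash__.
        self.device_backend_set = device_backend_set

    def __eq__(self, other):
        return isinstance(other, ToySharding) and self.device == other.device

    def __hash__(self):
        return hash(self.device)


@lru_cache(maxsize=None)
def convert_buggy(sharding):
    # The cache key is only `sharding`, so two equal shardings share one entry.
    return ("converted", sharding.device, sharding.device_backend_set)


@lru_cache(maxsize=None)
def _convert_fixed(sharding, device_backend_set):
    # The flag is part of the cache key, so differing flags miss the cache.
    return ("converted", sharding.device, device_backend_set)


def convert_fixed(sharding):
    return _convert_fixed(sharding, sharding.device_backend_set)


plain = ToySharding("cpu:0")
pinned = ToySharding("cpu:0", device_backend_set=True)

stale = (convert_buggy(plain), convert_buggy(pinned))
fresh = (convert_fixed(plain), convert_fixed(pinned))
```

In this sketch `stale[1]` repeats the first result and loses the flag, while `fresh[1]` carries it.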
Yash Katariya
2694bf6207 Use set equality operators instead of intersection because I didn't know set had equality operators.
PiperOrigin-RevId: 530688786
2023-05-09 12:55:47 -07:00
jax authors
68ba54241c Merge pull request #15929 from gnecula:fix_mlir_ir
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
Peter Hawkins
a89c377762 [GPU] Fix another instance of missing stream synchronization in RNN kernels.
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
jax authors
a2b5bd5230 Merge pull request #15931 from geraschenko:bcoo_reshape
PiperOrigin-RevId: 530657565
2023-05-09 10:58:53 -07:00
Anton Geraschenko
27aa5fb774 Make dimensions argument of bcoo_reshape optional. 2023-05-09 10:38:18 -07:00
Yash Katariya
18d19caa1c Add McJAX resharding to device_put. Allow resharding if inputs and target sharding have the same set of devices but different order.
We can generalize this in JAX slowly and carefully; doing so would likely require a refactor of how device_assignment is chosen.

Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
2023-05-09 09:58:12 -07:00
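The condition stated in 18d19caa1c (same set of devices, different order) can be sketched as a plain predicate. This is only an illustration of the check, assuming devices are represented by hashable ids; `same_device_set` is a hypothetical helper, not a JAX function.

```python
from typing import Hashable, Sequence


def same_device_set(src: Sequence[Hashable], dst: Sequence[Hashable]) -> bool:
    """True when both assignments cover exactly the same devices, in any order."""
    # Set equality ignores ordering; the length check guards against duplicates.
    return len(src) == len(dst) and set(src) == set(dst)


reordered = same_device_set([0, 1, 2, 3], [3, 2, 1, 0])
different = same_device_set([0, 1, 2, 3], [0, 1, 2, 4])
```

Here `reordered` is true and `different` is false.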
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e Import "ir" directly rather than as "mlir.ir" 2023-05-09 17:55:13 +02:00
jax authors
cb3a4f3dbf Merge pull request #15859 from gnecula:poly_rng
PiperOrigin-RevId: 530606467
2023-05-09 07:43:31 -07:00
George Necula
de2a811fe9 [shape_poly] Improvements and more testing for shape polymorphism for random primitives
  * Added support for shape polymorphism for partitionable threefry and for
    random_split.
  * Removed footgun that was ignoring the partitionable flag in presence of
    shape polymorphism.
  * Replicated the PRNG tests for threefry (partitionable and non-partitionable),
    and unsafe_rbg.
  * Added general support for overriding jax.config flags for PolyHarness.

This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
2023-05-09 13:55:27 +02:00
George Necula
16881e623f [shape_poly] Improve testing of vmap test harnesses
Previously, we disabled `check_result` (check that the JAX native and JAX with shape polymorphism produce the same result) for test harnesses that are created by vmap on primitive harnesses if the primitive harness has a custom assertion.

Now we enable that checking even for those harnesses, and we use the same custom assertion.

PiperOrigin-RevId: 530547784
2023-05-09 02:29:22 -07:00
George Necula
d66be780a5 [shape_poly] Add support for shape polymorphism with native serialization for lax.linalg.qr on TPU
PiperOrigin-RevId: 530522799
2023-05-09 00:25:45 -07:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
jax authors
236c74cad7 Merge pull request #15909 from skye:version
PiperOrigin-RevId: 530450947
2023-05-08 17:22:23 -07:00
Peter Hawkins
f168a1560c [GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
May fix flaky failures in CI.

Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.

PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
00b75aff82 Add tests for negative inputs to top-k.
Make the top-k test inputs larger.

This test would have caught the top-k bug fixed by https://github.com/openxla/xla/pull/2809

PiperOrigin-RevId: 530398528
2023-05-08 13:51:24 -07:00
Peter Hawkins
6b9a109939 Use stream-synchronized copy in rnn_kernels.cc.
May fix flaky wrong outputs sometimes seen in CI.

Also check for errors in another use of gpuStreamSynchronize().

PiperOrigin-RevId: 530391917
2023-05-08 13:28:08 -07:00
George Necula
821b38da12 [shape_poly] Remove old code disabling some tests
PiperOrigin-RevId: 530390362
2023-05-08 13:20:52 -07:00
Skye Wanderman-Milne
5e9364abc6 Revert setup.py changes.
This reverts the setup.py changes from
f28b20175f307d5a56502446a9706480126a5bd4. We actually need to fix some
more issues before releasing 0.4.9, so fix the install at HEAD in the
meantime.
2023-05-08 09:58:51 -07:00
John QiangZhang
47df8628a0 Fix the problem where a tf function returns StatefulPartitionedCall during jax2tf.call_tf.
PiperOrigin-RevId: 529964653
2023-05-06 08:30:26 -07:00
jax authors
d508f08121 Merge pull request #15893 from froystig:jex-jep
PiperOrigin-RevId: 529812297
2023-05-05 14:17:12 -07:00
Roy Frostig
ce840a9cd8 JEP: jax.extend, a module for extensions 2023-05-05 13:50:22 -07:00
jax authors
a4382d7600 Merge pull request #15890 from jakevdp:delete-slice
PiperOrigin-RevId: 529782623
2023-05-05 12:19:25 -07:00
Jake VanderPlas
882edd4924 jnp.delete: avoid large trace-time constant when deleting slice 2023-05-05 11:38:58 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding.

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
jax authors
d992080bfa Merge pull request #15885 from cgarciae:improve-serial-loop-docs
PiperOrigin-RevId: 529747371
2023-05-05 10:09:56 -07:00
Peter Hawkins
e8c735125c Disable more tests that are flaky in CI.
PiperOrigin-RevId: 529724306
2023-05-05 08:33:33 -07:00
Cristian Garcia
36c6fce9d5 improve serial_loop docstring 2023-05-05 15:27:33 +00:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
36ad0d4459 Add docs on how to create a jax.Array from data parallel host local inputs
PiperOrigin-RevId: 529579626
2023-05-04 19:12:08 -07:00
jax authors
62efeb83a9 Improve the bisection strategy in eigh.
Background: Currently, we pad the sub-matrices that occur during the spectral bisection algorithm to fit in a small number of buckets, in order to keep compilation time down. Each unique bucket size gives rise to a separate JIT compilation. The current strategy uses powers of two times the termination size of 256, below which we switch to a Jacobi solver. One issue is that the bisection step rarely splits the matrix in two exact equal parts, so one of the child-problems is forced to use the large bucket size of its parent, which wastes significant device cycles.

This change modifies the bucket selection strategy to not use [256, 512, 1024, ... n], but instead include a little slack at each level, such that both sub-problems from a non-perfect split will likely fall into the smaller bucket size. Specifically, we add 4% slack and round up to the next larger multiple of 32. These heuristic values were found experimentally. As an example, for n = 2048, we get the bucket sizes [2048, 1056, 544, 288, 256].

Measuring runtimes on random matrices of size 512, 1024, and 2048, we see significant speedups:

   N | wall time before | wall time after
=========================================
 512 |          27.8 ms |         24.8 ms
1024 |          97.6 ms |         79.3 ms
2048 |         414.5 ms |        308.0 ms

PiperOrigin-RevId: 529567005
2023-05-04 18:01:23 -07:00
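The bucket sizes quoted in 62efeb83a9 can be reproduced with the following sketch. The commit message does not give the exact formula, so the shrink factor of 1.96 per level (each half gets roughly 2% extra, about 4% in total) is an assumption chosen because it yields the quoted list for n = 2048; `bucket_sizes` and its parameters are hypothetical names, not the code in eigh.

```python
def round_up(i: int, multiple: int) -> int:
    """Round i up to the next multiple of `multiple`."""
    return ((i + multiple - 1) // multiple) * multiple


def bucket_sizes(n, termination_size=256, shrink=1.96, granularity=32):
    """Bucket sizes from n down to termination_size.

    `shrink` is an assumed value: dividing by slightly less than 2 leaves
    slack so that both halves of an uneven split fit the smaller bucket.
    """
    sizes = []
    i = n
    while i > termination_size:
        sizes.append(round_up(i, granularity))
        i = int(i / shrink)
    sizes.append(termination_size)
    return sizes


sizes_2048 = bucket_sizes(2048)
```

With these assumed parameters, `sizes_2048` equals the list given in the commit message, `[2048, 1056, 544, 288, 256]`; with an exact halving it would be the old power-of-two ladder.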
jax authors
99e7e8ee17 Merge pull request #15874 from jakevdp:keyarray-make-array
PiperOrigin-RevId: 529550502
2023-05-04 16:52:29 -07:00
Jake VanderPlas
4db717c52a KeyArray: support make_array_from_* APIs 2023-05-04 16:32:49 -07:00
jax authors
662293989a Merge pull request #15853 from jakevdp:jax2tf-keytype
PiperOrigin-RevId: 529531886
2023-05-04 15:33:02 -07:00
jax authors
505a752e23 Merge pull request #15873 from google:version
PiperOrigin-RevId: 529519828
2023-05-04 14:46:12 -07:00
Skye Wanderman-Milne
f28b20175f Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release 2023-05-04 14:38:46 -07:00
Jake VanderPlas
b031cc2660 jax2tf: better handling for opaque dtypes 2023-05-04 14:22:15 -07:00
Matthew Johnson
2845df03fc In jax.remat/jax.checkpoint, don't cache on Tracers in static args
Why do we have caching in jax.remat at all? I added it in
https://github.com/google/jax/pull/11743 without much justification other than
it made some tests faster. I think I was worried that the switch to the new
remat's "initial-style" (jaxpr forming up-front) approach would regress
eager-mode performance, so I added benchmarks to measure it and then made those
fast with caching.

But the caching seems a bit too aggressive when static_argnums are involved. In
particular, I allowed caching on Tracer arguments (by object id). That seems
dangerous!

So the change here is to check whether any of the arguments marked static by
static_argnums are Tracers. If so, skip the caching. This change happens not to
affect the benchmarks at all.

PiperOrigin-RevId: 529502687
2023-05-04 13:42:00 -07:00
jax authors
e6e6490ab0 Merge pull request #15247 from jakevdp:ml-dtypes-finfo
PiperOrigin-RevId: 529463737
2023-05-04 11:21:04 -07:00
Jake VanderPlas
59e6ed213e Use ml_dtypes definition for jnp.finfo 2023-05-04 10:40:44 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
jax authors
68614b4dcc [XLA:TPU] Fix a bug in eigh that caused a slight loss of accuracy.
PiperOrigin-RevId: 529406623
2023-05-04 07:49:04 -07:00
Peter Hawkins
09fce87f54 Increase sharding of or disable some flaky CI tests.
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
George Necula
40aa4e1781 [shape_poly] Disable tests for eigh shape polymorphism.
We are seeing some failures when comparing the results
for eigh with shape polymorphism and without.
Normally, shape polymorphism should not change the HLO
so a golden comparison is not necessarily bad, even though
for eigh we should check for correctness of the results
rather than identity.

We need to investigate this further but meanwhile turn
off these tests. The changes introduced recently for
shape polymorphism for eigh are not affecting the
code paths in absence of shape polymorphism. So it
is appropriate to just turn off the tests, and add
an error that shape polymorphism for eigh on
GPU is not ready.

PiperOrigin-RevId: 529388749
2023-05-04 06:14:18 -07:00
Adam Paszke
9c5e3f7ecc Verify that slices are trivial before discarding them in state primitives
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...

PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
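The check described in 9c5e3f7ecc can be sketched as follows. The commit message does not say what happens for a non-trivial slice, so raising an error here is an assumption; `is_trivial_slice` and `check_slices` are hypothetical helpers, not the state-primitive code.

```python
def is_trivial_slice(idx: slice, dim_size: int) -> bool:
    """True when idx selects the whole dimension, i.e. behaves like `[:]`."""
    # slice.indices normalizes None and out-of-range bounds to (start, stop, step).
    return idx.indices(dim_size) == (0, dim_size, 1)


def check_slices(indices, shape):
    """Raise unless every slice can safely be discarded (assumed behavior)."""
    for idx, dim_size in zip(indices, shape):
        if not is_trivial_slice(idx, dim_size):
            raise NotImplementedError(f"non-trivial slice {idx} for dim {dim_size}")


whole = is_trivial_slice(slice(None), 4)      # r[:]
explicit = is_trivial_slice(slice(0, 4), 4)   # r[0:4] on a length-4 ref
partial = is_trivial_slice(slice(0, 1), 4)    # r[0:1] on a length-4 ref
```

Under this sketch `r[0:1]` on a length-4 ref is rejected instead of being silently treated as `r[:]`.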
jax authors
ebcad11862 Merge pull request #15842 from gnecula:tf_eval_shape
PiperOrigin-RevId: 529292794
2023-05-03 22:09:00 -07:00
George Necula
f66d15c831 [shape_poly] Add a version of jax.eval_shape that works with shape polymorphism
The intended use is to find the output shapes for a function in the
presence of shape polymorphism, and to compute the
`polymorphic_shapes` value that can be used in a subsequent
call to `jax2tf.convert`.
2023-05-04 06:54:13 +02:00