12232 Commits

Author SHA1 Message Date
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Yash Katariya
68b9eaf0ee [Rollback] Use checkpoint count rather than final_checkpoint_dir as the barrier ID so that it can be unique everytime.
PiperOrigin-RevId: 458496801
2022-07-01 10:26:50 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
jax authors
d785f30e67 Merge pull request #11336 from hawkinsp:lobpcg
PiperOrigin-RevId: 458467593
2022-07-01 07:41:11 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
Peter Hawkins
1d83f7dd6f Disable lobpcg F32 consistency test on GPU.
The test appears to be failing at least some of the time

Typical failure:

=================================== FAILURES ===================================
___________ F32LobpcgTest.testLobpcgConsistencyF32cluster_k_2__n100 ____________
[gw12] linux -- Python 3.9.12 /usr/local/bin/python3.9
tests/lobpcg_test.py:370: in testLobpcgConsistencyF32
    self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float32)
tests/lobpcg_test.py:203: in checkLobpcgConsistency
    self.assertLess(
E   AssertionError: DeviceArray(20, dtype=int32, weak_type=True) not less than 20 : expected early convergence iters 20 < max 20
2022-07-01 09:47:49 -04:00
Sharad Vikram
7b59bd02ae Deflake debugger_test
PiperOrigin-RevId: 458392106
2022-06-30 23:10:32 -07:00
jax authors
b05e2b4c57 Merge pull request #11123 from ROCmSoftwarePlatform:revert_rocm_unit_test_enablement
PiperOrigin-RevId: 458356761
2022-06-30 18:24:08 -07:00
Yash Katariya
4635ec676d Add start and end logging for commit to the storage layer
PiperOrigin-RevId: 458343644
2022-06-30 16:58:06 -07:00
jax authors
33f1f40b20 Merge pull request #11298 from pschuh:axis-cache-env
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
jax authors
a31a0562bf Merge pull request #11326 from vlad17:f64-cpu
PiperOrigin-RevId: 458328276
2022-06-30 15:37:36 -07:00
Vladimir Feinberg
4bfb26c709 Skip F64 tests on GPU.
I had erroneously assumed that GPU would be as-high accuracy for f64 (both in numerics and eigh) when submitting #3112, so I did not disable f64 tests on that platform. This is of course not the case, so those tests should be disabled.
2022-06-30 15:24:10 -07:00
jax authors
ba36ef12a0 Merge pull request #11268 from mattjj:djax-ad-linearize
PiperOrigin-RevId: 458318428
2022-06-30 14:48:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
jax authors
4446c73d47 Merge pull request #10962 from vlad17:topk
PiperOrigin-RevId: 458312605
2022-06-30 14:21:24 -07:00
Vladimir Feinberg
76fcf63fb4 Add initial LOBPCG top-k eigenvalue solver (#3112)
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see jax.experimental.linalg.standard_lobpcg documentation.

This is a partial implementation of the similar [scipy lobpcg
function](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lobpcg.html).
2022-06-30 13:01:54 -07:00
jax authors
4e224bcfb9 [jax2tf] Add support for common audio convolutions (1D variants, dilated depthwise, transpose with SAME padding).
PiperOrigin-RevId: 458266485
2022-06-30 11:02:43 -07:00
Felix Chern
61b3dc5801 [JAX] Update approx_top_k doc with arxiv link.
PiperOrigin-RevId: 458258457
2022-06-30 10:29:22 -07:00
Yash Katariya
98ae6aa12b Use checkpoint count rather than final_checkpoint_dir as the barrier ID so that it can be unique everytime.
PiperOrigin-RevId: 458242788
2022-06-30 09:20:31 -07:00
Colin Gaffney
fa1a93195a Create AsyncManager, which factors out thread management functionality from GlobalAsyncCheckpointManager and makes it available for use (such as in Orbax) by classes supporting async read/write.
PiperOrigin-RevId: 458081905
2022-06-29 15:47:42 -07:00
jax authors
856eb3cad5 Merge pull request #11311 from jakevdp:fix-vmap-kwarg
PiperOrigin-RevId: 458076604
2022-06-29 15:22:43 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
jax authors
7d637d15e4 Merge pull request #11301 from sharadmv:for-loop
PiperOrigin-RevId: 458057864
2022-06-29 14:05:08 -07:00
jax authors
35f5d0fa06 Merge pull request #11266 from ROCmSoftwarePlatform:rocm_fixes
PiperOrigin-RevId: 458043701
2022-06-29 12:59:58 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
Sharad Vikram
790135989d Add scan implementation using for and tests 2022-06-29 12:49:41 -07:00
jax authors
eb0052bdf2 Merge pull request #11296 from rsuderman:AddMLProgram
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
jax authors
2842ccd958 Merge pull request #11299 from sharadmv:debugger
PiperOrigin-RevId: 458013391
2022-06-29 10:51:39 -07:00
jax authors
6959ec8bdf Merge pull request #11308 from jakevdp:flat-unimplemented
PiperOrigin-RevId: 458008041
2022-06-29 10:32:11 -07:00
jax authors
79cc81075a Merge pull request #11307 from jakevdp:bfgs-fix
PiperOrigin-RevId: 458007403
2022-06-29 10:26:38 -07:00
Jake VanderPlas
8336a8b2d8 DeviceArray: raise explicit NotImplementedError for arr.flat 2022-06-29 10:11:22 -07:00
Sharad Vikram
e8bd71b31c Add JAX debugger 2022-06-29 10:08:58 -07:00
Jake VanderPlas
93deb86710 jax.scipy.optimize: fix type inconsistency 2022-06-29 09:47:53 -07:00
jax authors
637bb61915 Merge pull request #11264 from PWhiddy:patch-1
PiperOrigin-RevId: 457997298
2022-06-29 09:40:46 -07:00
jax authors
dd948141d5 Merge pull request #11302 from ROCmSoftwarePlatform:upgrade_to_rocm_52
PiperOrigin-RevId: 457994618
2022-06-29 09:27:24 -07:00
jax authors
95d7a6677e Merge pull request #11234 from jakevdp:binop-validation
PiperOrigin-RevId: 457987820
2022-06-29 08:56:45 -07:00
Jake VanderPlas
39b0ff7eb6 jnp.ndarray: raise TypeError for binary operations with builtin collections 2022-06-29 08:22:05 -07:00
Rohit Santhanam
080cf47002 [ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork. 2022-06-29 14:34:08 +00:00
Rohit Santhanam
721602ef59 Upgrade ROCm build docker to ROCm version 5.2. 2022-06-28 21:43:07 -07:00
jax authors
7011de56ef Merge pull request #11292 from ArjunSharda:main
PiperOrigin-RevId: 457859272
2022-06-28 17:55:01 -07:00
jax authors
a319dd842c Merge pull request #11237 from dfm:kde
PiperOrigin-RevId: 457854908
2022-06-28 17:27:32 -07:00
jax authors
12cb4c2d0c Merge pull request #11300 from sharadmv:release
PiperOrigin-RevId: 457847408
2022-06-28 16:46:18 -07:00
Robert Suderman
45046857f6 Fix ModuleNotFoundError for phawkins only with version 2022-06-28 22:42:45 +00:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
jax authors
10320cbcc2 Merge pull request #11294 from sharadmv:release
PiperOrigin-RevId: 457828893
2022-06-28 15:12:14 -07:00
Sharad Vikram
1daea700f2 Bump JAX/Jaxlib versions 2022-06-28 14:36:47 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Dan F-M
dc2a50ff21 looser TPU precision 2022-06-28 17:07:30 -04:00
Robert Suderman
64aaeb2da9 Make ml_program import conditional 2022-06-28 20:43:50 +00:00