Change in preparation for allowing JAX tests to run under Bazel.
Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.
PiperOrigin-RevId: 458522398
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.
Update symlink_files Bazel macro to a newer version.
PiperOrigin-RevId: 458481396
The test appears to be failing at least some of the time
Typical failure:
=================================== FAILURES ===================================
___________ F32LobpcgTest.testLobpcgConsistencyF32cluster_k_2__n100 ____________
[gw12] linux -- Python 3.9.12 /usr/local/bin/python3.9
tests/lobpcg_test.py:370: in testLobpcgConsistencyF32
self.checkLobpcgConsistency(matrix_name, n, k, m, jnp.float32)
tests/lobpcg_test.py:203: in checkLobpcgConsistency
self.assertLess(
E AssertionError: DeviceArray(20, dtype=int32, weak_type=True) not less than 20 : expected early convergence iters 20 < max 20
I had erroneously assumed that GPU would be as-high accuracy for f64 (both in numerics and eigh) when submitting #3112, so I did not disable f64 tests on that platform. This is of course not the case, so those tests should be disabled.
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.
For details, see jax.experimental.linalg.standard_lobpcg documentation.
This is a partial implementation of the similar [scipy lobpcg
function](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lobpcg.html).
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.
However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.
PiperOrigin-RevId: 458041752
The crashes on Mac were, as best we can tell, unrelated to this PR.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457819042