* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(), since on Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init__.py to jaxlib/__init__.py and have the script move it into position, so that the output directory for jaxlib is an empty directory that the script creates.
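As a hedged illustration of why the switch to shutil.copy() matters: unlike copyfile(), copy() also copies the permission bits, so an executable script stays executable at the destination. A minimal sketch (not the build script itself):

```python
import os
import shutil
import tempfile

# Create a source file and mark it executable, as a build script would be.
src_dir = tempfile.mkdtemp()
src = os.path.join(src_dir, "build.py")
with open(src, "w") as f:
    f.write("#!/usr/bin/env python\n")
os.chmod(src, 0o755)

dst_copyfile = os.path.join(src_dir, "copyfile_result.py")
dst_copy = os.path.join(src_dir, "copy_result.py")

shutil.copyfile(src, dst_copyfile)  # copies contents only
shutil.copy(src, dst_copy)          # copies contents and permission bits

print(os.access(dst_copyfile, os.X_OK))  # typically False (default umask)
print(os.access(dst_copy, os.X_OK))      # True: execute bit preserved
```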
1. Build on Windows
2. Fix OverflowError
When calling `key = random.PRNGKey(0)`, an OverflowError was raised:
"Python int too large to convert to C long" when casting the value
4294967295 (0xFFFFFFFF) from a Python int to int32.
3. Fix the file path in the errors_test regex
4. Handle the ValueError raised by os.path.commonpath
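The overflow in (2) comes from converting 0xFFFFFFFF, which exceeds the signed int32 range, by value. A minimal sketch of the underlying issue using NumPy directly (not JAX's actual fix): reinterpreting the bits of the unsigned value avoids the overflowing value conversion.

```python
import numpy as np

value = 4294967295  # 0xFFFFFFFF, too large for a signed int32

# A direct value conversion like np.int32(value) can raise
# OverflowError, since the value exceeds the signed 32-bit range.

# Reinterpreting the bits of the unsigned value is safe:
as_int32 = np.uint32(value).view(np.int32)
print(int(as_int32))  # -1: same bit pattern, signed interpretation
```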
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
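The max/mean figures above are absolute errors against a reference FFT. A hedged sketch of how such an accuracy comparison can be computed, using NumPy's float64 FFT as the reference and a float32 input as a stand-in for the lower-precision implementation (this is not the benchmark's actual code):

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(1 << 14)

ref = np.fft.fft(x)                      # float64 reference transform
test = np.fft.fft(x.astype(np.float32))  # single-precision input

# Absolute error statistics, in the style of the benchmark output above.
err = np.abs(test - ref)
print("max:", err.max())
print("mean:", err.mean())
print("min:", err.min())
```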
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
This change:
* Updates our jaxlib build scripts to add `+cudaXXX` to the wheel
version, where XXX is the CUDA version number (e.g. `110`). nocuda
builds remain unchanged and do not have this extra identifier.
* Adds `generate_release_index.py`, which writes an html page that pip
can use to find the cuda wheels. (I based this format on
PyTorch's wheel index.)
* Updates the README to use the new local version identifier + wheel
index.
The end result is that the command to install cuda wheels is now much
simpler.
I manually made copies of the latest jaxlib 0.1.55 wheels that have
the local version identifiers, so the new installation commands
already work (as well as the old ones, until the next jaxlib release
using the new tooling).
For now, I put the html index in the GCP bucket with the wheels. We
can move it to a prettier URL if/when we have one.
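For illustration, a minimal sketch of the kind of flat HTML index pip can consume via `-f`/`--find-links` (the real `generate_release_index.py` may differ, and the wheel filenames below are hypothetical):

```python
# Hypothetical wheel filenames; the real script would enumerate the
# wheels actually present in the GCP bucket.
wheels = [
    "jaxlib-0.1.55+cuda110-cp38-none-manylinux2010_x86_64.whl",
    "jaxlib-0.1.55+cuda102-cp38-none-manylinux2010_x86_64.whl",
]

# pip's --find-links accepts a flat page of <a href> links to wheels.
links = "\n".join(f'<a href="{name}">{name}</a><br>' for name in wheels)
index_html = f"<html><body>\n{links}\n</body></html>"
print(index_html)
```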
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
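Two of the cleanups above, sketched on hypothetical code rather than the actual diff: merging isinstance calls into one call with a tuple of types, and using a with statement to open a file.

```python
import tempfile

def describe(value):
    # Before: isinstance(value, int) or isinstance(value, float)
    # After: a single call with a tuple of types.
    if isinstance(value, (int, float)):
        return "number"
    return "other"

def read_config(path):
    # Before: f = open(path); data = f.read(); f.close()
    # After: the with statement closes the file even on error.
    with open(path) as f:
        return f.read()

print(describe(3.5))  # number
with tempfile.NamedTemporaryFile("w", suffix=".cfg", delete=False) as tmp:
    tmp.write("key=value")
print(read_config(tmp.name))  # key=value
```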
* Add options to compute L/R eigenvectors in geev.
The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.
Reformulate eig-related operations based on the new geev API.
* Addressed hawkinsp's comments from google/jax#3882.
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
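geev is the LAPACK routine behind general (non-symmetric) eigendecomposition. A sketch of the right-eigenvector relation it computes, checked here with NumPy (whose `linalg.eig` also wraps geev) rather than the jaxlib binding itself:

```python
import numpy as np

a = np.array([[2.0, 1.0],
              [1.0, 3.0]])

# numpy.linalg.eig wraps LAPACK geev and returns eigenvalues w and
# right eigenvectors v satisfying a @ v[:, i] == w[i] * v[:, i].
w, v = np.linalg.eig(a)

# Broadcasting v * w scales column i of v by w[i].
residual = np.abs(a @ v - v * w).max()
print(residual < 1e-12)  # True
```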
1. `wheel.pep425tags` has been removed as of
https://github.com/pypa/setuptools/pull/1829. Use the new
`packaging.tags` instead.
2. Add `--allow-downgrades` to the cuda install command. I'm not sure this
is always necessary, but I ran into a failure without it, probably due
to a cached docker image.
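A sketch of querying the replacement API (this assumes the `packaging` package is installed; `packaging.tags.sys_tags()` yields the platform's supported wheel tags in priority order):

```python
from packaging import tags

# sys_tags() replaces the removed wheel.pep425tags module: it yields
# the wheel tags this interpreter/platform supports, most specific first.
supported = list(tags.sys_tags())
print(supported[0])  # e.g. cp38-cp38-manylinux_2_17_x86_64 (platform-dependent)
```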
We do this to see whether it reduces the incidence of errors when
fetching the tf-nightly package. These tests are run when we import
the code in Google.
* Add jax.image.resize.
This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.
While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
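To illustrate the idea behind a linear resize (this is a hypothetical 1-D helper built on NumPy, not `jax.image.resize` itself): sample the output at evenly spaced positions in the input's coordinate space and interpolate.

```python
import numpy as np

def resize_1d_linear(signal, new_len):
    """Hypothetical helper: linearly resize a 1-D signal.

    Illustrates the sampling-and-interpolation idea behind an image
    resize; the real implementation handles N-D arrays and more methods.
    """
    old_len = len(signal)
    # Output sample positions, expressed in input coordinates.
    x = np.linspace(0, old_len - 1, new_len)
    return np.interp(x, np.arange(old_len), signal)

sig = np.array([0.0, 1.0, 2.0, 3.0])
out = resize_1d_linear(sig, 7)
print(out)  # values 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0
```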
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
I am trying to solve this by (a) making host_callback_test stop the
outfeed receiver on finish and infeed_test on start, and (b)
telling pytest-xdist to run all the tests from one file in
a single worker.
* Use dynamic loading to locate CUDA libraries in jaxlib.
This should allow jaxlib CUDA wheels to be manylinux2010 compliant.
* Tag CUDA jaxlib wheels as manylinux2010.
Drop support for CUDA 9.2, add support for CUDA 11.0.
* Reorder CUDA imports.
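Dynamic loading here means resolving the CUDA libraries at runtime rather than linking against them at build time, which is what lets the wheel satisfy manylinux2010. A hedged sketch of the general mechanism using ctypes and the C math library as a stand-in (for jaxlib the library would be something like cudart, found via the CUDA installation paths):

```python
import ctypes
import ctypes.util

# Locate a shared library by name at runtime instead of linking to it
# at build time; "m" (libm) stands in for a CUDA runtime library here.
libm_path = ctypes.util.find_library("m")
libm = ctypes.CDLL(libm_path)

# Declare the C signature before calling through the handle.
libm.sqrt.restype = ctypes.c_double
libm.sqrt.argtypes = [ctypes.c_double]
print(libm.sqrt(9.0))  # 3.0
```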
We choose the same set of CUDA compute capabilities as TensorFlow (minus 3.7, which TF is apparently considering dropping anyway).
This avoids a slow PTX -> SASS compilation on first time startup.
* raise minimum Bazel version to 2.0.0 to match TensorFlow.
* set --experimental_repo_remote_exec since it is required by the TF build.
* bump TF/XLA version.
* use the --config=short_logs trick from TF to suppress build warnings.