65 Commits

Author SHA1 Message Date
Nitin Srinivasan
031614c22b Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
PiperOrigin-RevId: 737632771
2025-03-17 08:59:06 -07:00
Nitin Srinivasan
5944c9ed65 Install test dependencies from test-requirements.txt instead of requirements.in
PiperOrigin-RevId: 736878834
2025-03-14 08:57:20 -07:00
jax authors
007fc7a6f1 Remove version limit for setuptools dependency.
PiperOrigin-RevId: 735453796
2025-03-10 11:36:17 -07:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
Kanglan Tang
59a3552ae6 Remove portpicker for free threaded python 3.13t in test-requirements.txt
PiperOrigin-RevId: 722776783
2025-02-03 13:30:01 -08:00
jax authors
a0b0a8e5a1 Set minimum supported Python version to 3.10 for matplotlib.
Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version.

We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel.

PiperOrigin-RevId: 692260046
2024-11-01 12:34:28 -07:00
Vadym Matsishevskyi
a75d94622c Reverts 72f9a493589a1046e6927a5f16d7dc71df530743
PiperOrigin-RevId: 691843537
2024-10-31 10:05:22 -07:00
Peter Hawkins
72f9a49358 Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
Vadym Matsishevskyi
6d8950c04f Cleanup requirements.in and test-requirements.txt
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
8bitmp3
60a06fd4c9
Update pillow version in JAX build test-requirements.txt 2024-09-25 14:55:46 +00:00
jax authors
0c7c71e640 Update python version from 3.12 to 3.13.0rc2 in Github presubmit jobs.
PiperOrigin-RevId: 676140293
2024-09-18 14:49:42 -07:00
Sergei Lebedev
3b1b5fda81 Added filelock to test-requirements.txt and requirements lock files
This is a follow up to #21741.
2024-06-11 11:53:10 +01:00
Pearu Peterson
fdb5015909 Evaluate the correctness of JAX complex functions using mpmath as a reference 2024-03-21 23:35:29 +02:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Peter Hawkins
cf28e2c5fa Small improvements to build/build.py
Add a --verbose option that logs all shell() commands run by the script.
Remove some Python 2 backward compatibility logic related to urllib and shutil.

Enable debug logging on Windows wheel builds.

Also include setuptools in the build requirements and test for its presence in build.py.
2023-09-26 21:55:45 -04:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Skye Wanderman-Milne
2ca151ef5b profiler_test.py fixes and add coverage to Cloud TPU CI
* Add deps to test requirements, including in new
  `collect-profile-requirements.txt` (to avoid adding tensorflow to
  `test-requirements.txt`).
* Use correct Python executable `ProfilerTest.test_remote_profiler`
  (`python` sometimes defaults to python2)
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
  othewise `collect_profile` sometimes misses it.
2023-06-20 22:25:17 +00:00
Peter Hawkins
0c441574c4 Add NumPy as a test requirement.
The Windows CI currently installs all of the test requirements before building jaxlib, but NumPy is needed to build jaxlib.
Previously this came transitively via matplotlib.
2023-06-14 10:14:56 -04:00
Peter Hawkins
6b76937c53 Remove matplotlib from the test requirements.
In the Windows CI, we seem to be hitting the following error:

```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
    from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
    import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
    from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
    from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
    from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
    from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
    fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
    json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
    with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
    return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
    with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
    return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
    return self._accessor.open(self, flags, mode)
E   PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```

The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.
2023-06-14 09:02:49 -04:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Sharad Vikram
2d8b228706 Add function to visualize Shardings 2022-09-19 13:27:08 -07:00
Jake VanderPlas
f7731c8a29 Tests: require pillow>=9.1.0 & remove backward compatibility 2022-08-12 13:34:56 -07:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Jake VanderPlas
1f300e729b CI: pin pillow<9.1 to prevent deprecation warnings 2022-04-01 09:23:27 -07:00
Peter Hawkins
901d459e0d Add cloudpickle as a test requirement.
We have at least one test that tests pickling JAX objects.
2022-02-16 15:04:56 -05:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Skye Wanderman-Milne
2fcf3f7270 Remove .[minimum-jaxlib] from test-requirements.txt
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.
2021-09-23 12:24:24 -07:00
Jake VanderPlas
a5b6a4e6a9 CI: remove flake8 from test requirements. 2021-08-25 11:07:09 -07:00
dependabot[bot]
9f2863c66b Copybara import of the project:
--
57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>:

Bump flatbuffers from 1.12 to 2.0

Bumps [flatbuffers](https://github.com/google/flatbuffers) from 1.12 to 2.0.
- [Release notes](https://github.com/google/flatbuffers/releases)
- [Commits](https://github.com/google/flatbuffers/compare/v1.12.0...v2.0.0)

---
updated-dependencies:
- dependency-name: flatbuffers
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7686 from google:dependabot/pip/flatbuffers-2.0 57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f
PiperOrigin-RevId: 392097862
2021-08-20 17:13:26 -07:00
Jake VanderPlas
cbcd6eeadb CI: bump mypy & flake8 versions to newest 2021-08-20 14:35:37 -07:00
Jake VanderPlas
7fa151c5c3 cleanup: remove redundant entry from test-requirements 2021-08-20 10:09:14 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
0de4a60834 Update pillow pin to >= 8.3.1.
8.3.1 fixed the issue from https://github.com/google/jax/pull/7166.
2021-07-07 08:33:29 -04:00
Jake VanderPlas
4ba343aa83 CI: pin pillow dependency to 8.2 to avoid failures under 8.3 2021-07-01 16:32:35 -07:00
Jake VanderPlas
0c91be7b46 CI: temporarily pin numpy to <1.21 2021-06-22 11:15:16 -07:00
Peter Hawkins
07277f0785 Bump mypy version to 0.902. 2021-06-14 10:05:34 -04:00
Peter Hawkins
40c5e376d8 Pin flatbuffers 1.12 for CI tests. 2021-05-10 18:21:25 -04:00
Jake VanderPlas
f9a4162551 Specify minimum jaxlib version in a single location 2021-03-22 16:14:41 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Skye Wanderman-Milne
7a67b974ac jaxlib version bump etc. 2021-02-12 09:42:04 -08:00