mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Drop support for NumPy 1.17.
This commit is contained in:
parent
c4602475ca
commit
6e9169d100
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -62,7 +62,7 @@ jobs:
|
||||
os: ubuntu-latest
|
||||
enable-x64: 0
|
||||
# Test with the oldest legal NumPy version.
|
||||
package-overrides: "numpy==1.17.5 scipy==1.2.1"
|
||||
package-overrides: "numpy==1.18.5 scipy==1.2.1"
|
||||
num_generated_cases: 8
|
||||
# Test against latest jaxlib
|
||||
use-latest-jaxlib: true
|
||||
|
10
CHANGELOG.md
10
CHANGELOG.md
@ -10,7 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.2.19 (unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.18...main).
|
||||
|
||||
* Breaking changes:
|
||||
* Support for NumPy 1.17 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to a supported NumPy version.
|
||||
* New features:
|
||||
* Improved the support for shape polymorphism in jax2tf for operations that
|
||||
need to use a dimension size in array computation, e.g., `jnp.mean`.
|
||||
@ -21,7 +24,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Support for Python 3.6 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to a supported Python version.
|
||||
|
||||
* Support for NumPy 1.17 has been dropped, per the
|
||||
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
|
||||
Please upgrade to a supported NumPy version.
|
||||
|
||||
* The host_callback mechanism now uses one thread per local device for
|
||||
making the calls to the Python callbacks. Previously there was a single
|
||||
thread for all devices. This means that the callbacks may now be called
|
||||
|
@ -82,8 +82,8 @@ def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||
numpy_version = tuple(map(int, version.split('.')[:2]))
|
||||
if numpy_version < (1, 17):
|
||||
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
|
||||
if numpy_version < (1, 18):
|
||||
print("ERROR: JAX requires NumPy 1.18 or newer, found " + version + ".")
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
|
@ -32,7 +32,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
||||
python_requires='>=3.7',
|
||||
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
install_requires=['scipy', 'numpy>=1.18', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={
|
||||
|
2
setup.py
2
setup.py
@ -36,7 +36,7 @@ setup(
|
||||
package_data={'jax': ['py.typed']},
|
||||
python_requires='>=3.7',
|
||||
install_requires=[
|
||||
'numpy>=1.17',
|
||||
'numpy>=1.18',
|
||||
'absl-py',
|
||||
'opt_einsum',
|
||||
],
|
||||
|
@ -1406,7 +1406,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
|
||||
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
# TODO(phawkins): the promotion behavior changed in Numpy 1.17.
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@ -4226,8 +4225,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
|
||||
|
||||
# TODO(jakevdp): remove this ignore_warning when minimum numpy version is 1.17.0
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="Invalid value encountered.*")
|
||||
def np_fun(*args):
|
||||
args = [x if jnp.result_type(x) != jnp.bfloat16 else
|
||||
np.asarray(x, np.float32) for x in args]
|
||||
|
Loading…
x
Reference in New Issue
Block a user