mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
This commit is contained in:
parent
e9bc7ee866
commit
b7375b316b
@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
like bfloat16. These definitions were previously internal to JAX, but have
|
||||
been split into a separate package to facilitate sharing them with other
|
||||
projects.
|
||||
* JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
|
||||
|
||||
* Deprecations
|
||||
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
|
||||
|
@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
|
||||
numpy_version = tuple(map(int, version.split(".")[:2]))
|
||||
if numpy_version < (1, 20):
|
||||
print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".")
|
||||
if numpy_version < (1, 21):
|
||||
print("ERROR: JAX requires NumPy 1.21 or newer, found " + version + ".")
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
|
@ -46,7 +46,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.8',
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.20', 'ml_dtypes>=0.0.3'],
|
||||
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
4
setup.py
4
setup.py
@ -65,9 +65,9 @@ setup(
|
||||
python_requires='>=3.8',
|
||||
install_requires=[
|
||||
'ml_dtypes>=0.0.3',
|
||||
'numpy>=1.20',
|
||||
'numpy>=1.21',
|
||||
'opt_einsum',
|
||||
'scipy>=1.5',
|
||||
'scipy>=1.7',
|
||||
],
|
||||
extras_require={
|
||||
# Minimum jaxlib version; used in testing.
|
||||
|
@ -492,8 +492,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testReductionOfOutOfBoundsAxis(self): # Issue 888
|
||||
|
@ -1322,8 +1322,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
|
||||
if scipy_version < (1, 9, 0) and keepdims != True:
|
||||
self.skipTest("scipy < 1.9.0 only support keepdims == True")
|
||||
if numpy_version < (1, 21, 0) and contains_nans:
|
||||
self.skipTest("numpy < 1.21.0 only support contains_nans == False")
|
||||
|
||||
if contains_nans:
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user