mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Update minimum NumPy version to v1.24.
This commit is contained in:
parent
694cafb72b
commit
7f24837eef
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.31
|
||||
|
||||
* Changes
|
||||
* The minimum NumPy version is now 1.24.
|
||||
|
||||
## jaxlib 0.4.31
|
||||
|
||||
* Bug fixes
|
||||
|
@ -63,7 +63,7 @@ setup(
|
||||
install_requires=[
|
||||
'scipy>=1.9',
|
||||
"scipy>=1.11.1; python_version>='3.12'",
|
||||
'numpy>=1.22',
|
||||
'numpy>=1.24',
|
||||
'ml_dtypes>=0.2.0',
|
||||
],
|
||||
url='https://github.com/google/jax',
|
||||
|
3
setup.py
3
setup.py
@ -55,8 +55,7 @@ setup(
|
||||
install_requires=[
|
||||
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
|
||||
'ml_dtypes>=0.2.0',
|
||||
'numpy>=1.22',
|
||||
"numpy>=1.23.2; python_version>='3.11'",
|
||||
'numpy>=1.24',
|
||||
"numpy>=1.26.0; python_version>='3.12'",
|
||||
'opt_einsum',
|
||||
'scipy>=1.9',
|
||||
|
@ -214,7 +214,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
shape=all_shapes,
|
||||
dtype=numpy_dtypes,
|
||||
)
|
||||
@unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
|
||||
@jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks
|
||||
def testJaxToNumpy(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
@ -1679,9 +1679,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
|
||||
mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
|
||||
if numpy_version == (1, 23, 0) and mask.shape == (1,):
|
||||
# https://github.com/numpy/numpy/issues/21840
|
||||
self.skipTest("test fails for numpy v1.23.0")
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fun = lambda arg: np.delete(arg, mask, axis=axis)
|
||||
jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis)
|
||||
@ -1943,9 +1940,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
@unittest.skip("jax-metal fail.")
|
||||
@jtu.sample_product(dtype=inexact_dtypes)
|
||||
def testUniqueNans(self, dtype):
|
||||
if numpy_version == (1, 23, 0) and dtype == np.float16:
|
||||
# https://github.com/numpy/numpy/issues/21838
|
||||
self.skipTest("Known failure on numpy 1.23.0")
|
||||
def args_maker():
|
||||
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
@ -1966,8 +1960,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
@unittest.skip("jax-metal fail.")
|
||||
@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
|
||||
def testUniqueEqualNan(self, dtype, equal_nan):
|
||||
if numpy_version < (1, 24, 0):
|
||||
self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.")
|
||||
shape = (20,)
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -2669,10 +2661,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
|
||||
if numpy_version < (1, 24):
|
||||
np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
|
||||
else:
|
||||
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
|
||||
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
|
||||
|
||||
jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
@ -2699,7 +2688,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
|
||||
if numpy_version < (1, 24) or op == "dstack":
|
||||
if op == "dstack":
|
||||
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
|
||||
else:
|
||||
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
|
||||
|
@ -507,14 +507,12 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple)
|
||||
else [None, (shape[axis],), shape])
|
||||
],
|
||||
keepdims=([False, True] if numpy_version >= (1, 23) else [None]),
|
||||
keepdims=[False, True],
|
||||
returned=[False, True],
|
||||
)
|
||||
def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
kwds = dict(returned=returned)
|
||||
if keepdims is not None:
|
||||
kwds['keepdims'] = keepdims
|
||||
kwds = dict(returned=returned, keepdims=keepdims)
|
||||
if weights_shape is None:
|
||||
np_fun = lambda x: np.average(x, axis, **kwds)
|
||||
jnp_fun = lambda x: jnp.average(x, axis, **kwds)
|
||||
@ -527,15 +525,11 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
|
||||
np.float64: 1e-12, np.complex64: 1e-5}
|
||||
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
|
||||
if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None:
|
||||
# Known failure: https://github.com/numpy/numpy/issues/21850
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol)
|
||||
except ZeroDivisionError:
|
||||
self.skipTest("don't support checking for ZeroDivisionError")
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol)
|
||||
except ZeroDivisionError:
|
||||
self.skipTest("don't support checking for ZeroDivisionError")
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
|
@ -2048,9 +2048,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=inexact_dtypes)
|
||||
def testUniqueNans(self, dtype):
|
||||
if numpy_version == (1, 23, 0) and dtype == np.float16:
|
||||
# https://github.com/numpy/numpy/issues/21838
|
||||
self.skipTest("Known failure on numpy 1.23.0")
|
||||
def args_maker():
|
||||
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
@ -2070,8 +2067,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
|
||||
def testUniqueEqualNan(self, dtype, equal_nan):
|
||||
if numpy_version < (1, 24, 0):
|
||||
self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.")
|
||||
shape = (20,)
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -2784,10 +2779,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
|
||||
if numpy_version < (1, 24):
|
||||
np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
|
||||
else:
|
||||
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
|
||||
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
|
||||
|
||||
jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
@ -2814,7 +2806,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
else:
|
||||
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
|
||||
|
||||
if numpy_version < (1, 24) or op == "dstack":
|
||||
if op == "dstack":
|
||||
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
|
||||
else:
|
||||
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
|
||||
@ -5992,7 +5984,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
mismatches = {}
|
||||
|
||||
for name, (jnp_fun, np_fun) in func_pairs.items():
|
||||
if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']:
|
||||
if name in ['histogram', 'histogram2d', 'histogramdd']:
|
||||
# numpy 1.24 re-orders the density and weights arguments.
|
||||
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
|
||||
continue
|
||||
|
Loading…
x
Reference in New Issue
Block a user