mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
This commit is contained in:
parent
989a3304bf
commit
a560a29e12
@ -22,8 +22,6 @@ from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan
|
||||
from jax._src.scipy.special import betaln
|
||||
|
||||
scipy_version = tuple(map(int, scipy.version.version.split('.')[:2]))
|
||||
|
||||
|
||||
def logpmf(k, n, a, b, loc=0):
|
||||
"""JAX implementation of scipy.stats.betabinom.logpmf."""
|
||||
@ -45,7 +43,5 @@ def pmf(k, n, a, b, loc=0):
|
||||
return lax.exp(logpmf(k, n, a, b, loc))
|
||||
|
||||
|
||||
# betabinom was added in scipy 1.4.0
|
||||
if scipy_version >= (1, 4):
|
||||
logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf)
|
||||
pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf)
|
||||
logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf)
|
||||
pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf)
|
||||
|
@ -38,7 +38,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.7',
|
||||
install_requires=['scipy', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
2
setup.py
2
setup.py
@ -42,7 +42,7 @@ setup(
|
||||
'absl-py',
|
||||
'numpy>=1.19',
|
||||
'opt_einsum',
|
||||
'scipy>=1.2.1',
|
||||
'scipy>=1.5',
|
||||
'typing_extensions',
|
||||
'etils[epath]'
|
||||
],
|
||||
|
@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -26,7 +24,6 @@ from jax.config import config
|
||||
import jax.scipy.optimize
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
scipy_version = tuple(map(int, scipy.version.version.split('.')[:2]))
|
||||
|
||||
|
||||
def rosenbrock(np):
|
||||
@ -127,16 +124,12 @@ class TestBFGS(jtu.JaxTestCase):
|
||||
x0=initial_value,
|
||||
method='BFGS',
|
||||
).x
|
||||
# Scipy does type-promoting binary ops on JAX array inputs.
|
||||
# Newer versions of scipy (1.5+) don't have this issue.
|
||||
with (jax.numpy_dtype_promotion('standard') if scipy_version < (1, 5)
|
||||
else contextlib.nullcontext()):
|
||||
scipy_res = scipy.optimize.minimize(
|
||||
fun=opt_fn,
|
||||
jac=jax.grad(opt_fn),
|
||||
method='BFGS',
|
||||
x0=initial_value
|
||||
).x
|
||||
scipy_res = scipy.optimize.minimize(
|
||||
fun=opt_fn,
|
||||
jac=jax.grad(opt_fn),
|
||||
method='BFGS',
|
||||
x0=initial_value
|
||||
).x
|
||||
self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)
|
||||
|
||||
|
||||
|
@ -18,7 +18,6 @@ import itertools
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy as osp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
import jax
|
||||
@ -31,7 +30,6 @@ config.parse_flags_with_absl()
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
scipy_version = tuple(map(int, osp.version.version.split('.')[:2]))
|
||||
|
||||
|
||||
def genNamedParametersNArgs(n):
|
||||
@ -554,10 +552,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
return [k, n, a, b, loc]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
if scipy_version >= (1, 4):
|
||||
scipy_fun = osp_stats.betabinom.logpmf
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
scipy_fun = osp_stats.betabinom.logpmf
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testIssue972(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user