Increase the minimum scipy version to 1.5.

We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
This commit is contained in:
Peter Hawkins 2022-06-24 15:01:16 -04:00
parent 989a3304bf
commit a560a29e12
5 changed files with 13 additions and 27 deletions

View File

@ -22,8 +22,6 @@ from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan
from jax._src.scipy.special import betaln
scipy_version = tuple(map(int, scipy.version.version.split('.')[:2]))
def logpmf(k, n, a, b, loc=0):
"""JAX implementation of scipy.stats.betabinom.logpmf."""
@ -45,7 +43,5 @@ def pmf(k, n, a, b, loc=0):
return lax.exp(logpmf(k, n, a, b, loc))
# betabinom was added in scipy 1.4.0
if scipy_version >= (1, 4):
logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf)
pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf)
logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf)
pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf)

View File

@ -38,7 +38,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.7',
install_requires=['scipy', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[

View File

@ -42,7 +42,7 @@ setup(
'absl-py',
'numpy>=1.19',
'opt_einsum',
'scipy>=1.2.1',
'scipy>=1.5',
'typing_extensions',
'etils[epath]'
],

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from absl.testing import absltest, parameterized
import numpy as np
import scipy
@ -26,7 +24,6 @@ from jax.config import config
import jax.scipy.optimize
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:2]))
def rosenbrock(np):
@ -127,16 +124,12 @@ class TestBFGS(jtu.JaxTestCase):
x0=initial_value,
method='BFGS',
).x
# Scipy does type-promoting binary ops on JAX array inputs.
# Newer versions of scipy (1.5+) don't have this issue.
with (jax.numpy_dtype_promotion('standard') if scipy_version < (1, 5)
else contextlib.nullcontext()):
scipy_res = scipy.optimize.minimize(
fun=opt_fn,
jac=jax.grad(opt_fn),
method='BFGS',
x0=initial_value
).x
scipy_res = scipy.optimize.minimize(
fun=opt_fn,
jac=jax.grad(opt_fn),
method='BFGS',
x0=initial_value
).x
self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

View File

@ -18,7 +18,6 @@ import itertools
from absl.testing import absltest, parameterized
import numpy as np
import scipy as osp
import scipy.stats as osp_stats
import jax
@ -31,7 +30,6 @@ config.parse_flags_with_absl()
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
scipy_version = tuple(map(int, osp.version.version.split('.')[:2]))
def genNamedParametersNArgs(n):
@ -554,10 +552,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
return [k, n, a, b, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
if scipy_version >= (1, 4):
scipy_fun = osp_stats.betabinom.logpmf
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
scipy_fun = osp_stats.betabinom.logpmf
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
def testIssue972(self):