diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index e93e6117c..f2366ed7d 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -22,8 +22,6 @@ from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan from jax._src.scipy.special import betaln -scipy_version = tuple(map(int, scipy.version.version.split('.')[:2])) - def logpmf(k, n, a, b, loc=0): """JAX implementation of scipy.stats.betabinom.logpmf.""" @@ -45,7 +43,5 @@ def pmf(k, n, a, b, loc=0): return lax.exp(logpmf(k, n, a, b, loc)) -# betabinom was added in scipy 1.4.0 -if scipy_version >= (1, 4): - logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf) - pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf) +logpmf = _wraps(osp_stats.betabinom.logpmf, update_doc=False)(logpmf) +pmf = _wraps(osp_stats.betabinom.pmf, update_doc=False)(pmf) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index e87a9cde9..768477a5a 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -38,7 +38,7 @@ setup( author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.7', - install_requires=['scipy', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'], + install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'], url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ diff --git a/setup.py b/setup.py index 0b62f37de..293abfc22 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ setup( 'absl-py', 'numpy>=1.19', 'opt_einsum', - 'scipy>=1.2.1', + 'scipy>=1.5', 'typing_extensions', 'etils[epath]' ], diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index ac6b903ca..7d6f02177 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib - from absl.testing import absltest, parameterized import numpy as np import scipy @@ -26,7 +24,6 @@ from jax.config import config import jax.scipy.optimize config.parse_flags_with_absl() -scipy_version = tuple(map(int, scipy.version.version.split('.')[:2])) def rosenbrock(np): @@ -127,16 +124,12 @@ class TestBFGS(jtu.JaxTestCase): x0=initial_value, method='BFGS', ).x - # Scipy does type-promoting binary ops on JAX array inputs. - # Newer versions of scipy (1.5+) don't have this issue. - with (jax.numpy_dtype_promotion('standard') if scipy_version < (1, 5) - else contextlib.nullcontext()): - scipy_res = scipy.optimize.minimize( - fun=opt_fn, - jac=jax.grad(opt_fn), - method='BFGS', - x0=initial_value - ).x + scipy_res = scipy.optimize.minimize( + fun=opt_fn, + jac=jax.grad(opt_fn), + method='BFGS', + x0=initial_value + ).x self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 2dfcde31d..8ec6d9a5b 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -18,7 +18,6 @@ import itertools from absl.testing import absltest, parameterized import numpy as np -import scipy as osp import scipy.stats as osp_stats import jax @@ -31,7 +30,6 @@ config.parse_flags_with_absl() all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] -scipy_version = tuple(map(int, osp.version.version.split('.')[:2])) def genNamedParametersNArgs(n): @@ -554,10 +552,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): return [k, n, a, b, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): - if scipy_version >= (1, 4): - scipy_fun = osp_stats.betabinom.logpmf - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) + scipy_fun = osp_stats.betabinom.logpmf + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) def testIssue972(self):