diff --git a/CHANGELOG.md b/CHANGELOG.md index 346c399b3..1835f0857 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes: * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025. + * The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum + supported version until June 2025. * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than `optimize='optimal'`. This avoids exponentially-scaling trace-time in the case of many arguments ({jax-issue}`#25214`). diff --git a/jaxlib/setup.py b/jaxlib/setup.py index c2efd3d7b..b3a37a25f 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -61,8 +61,7 @@ setup( packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.10', install_requires=[ - 'scipy>=1.10', - "scipy>=1.11.1; python_version>='3.12'", + 'scipy>=1.11.1', 'numpy>=1.25', 'ml_dtypes>=0.2.0', ], diff --git a/setup.py b/setup.py index b3bd4a346..39508388b 100644 --- a/setup.py +++ b/setup.py @@ -60,8 +60,7 @@ setup( 'numpy>=1.25', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', - 'scipy>=1.10', - "scipy>=1.11.1; python_version>='3.12'", + 'scipy>=1.11.1', ], extras_require={ # Minimum jaxlib version; used in testing. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0da09e232..65f7c8145 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1686,7 +1686,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], - batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()], + batch_size=[(), (2,), (3, 4)], dtype=int_types + float_types + complex_types ) def testExpm(self, n, batch_size, dtype):