Test: make scipy version parsing compatible with pre-releases

This commit is contained in:
Jake VanderPlas 2024-01-12 14:35:28 -08:00
parent 963cd6fd3d
commit 1870eee062
4 changed files with 5 additions and 5 deletions

View File

@ -1268,14 +1268,14 @@ def strict_promotion_if_dtypes_match(dtypes):
return jax.numpy_dtype_promotion('standard')
_version_regex = re.compile(r"([0-9]+(?:\.[0-9]+)*)(?:(rc|dev).*)?")
def _parse_version(v: str) -> tuple[int, ...]:
def parse_version(v: str) -> tuple[int, ...]:
m = _version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse version '{v}'")
return tuple(int(x) for x in m.group(1).split('.'))
def numpy_version():
return _parse_version(np.__version__)
return parse_version(np.__version__)
def parameterized_filterable(*,
kwargs: Sequence[dict[str, Any]],

View File

@ -39,7 +39,7 @@ from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
scipy_version = jtu.parse_version(scipy.version.version)
T = lambda x: np.swapaxes(x, -1, -2)

View File

@ -29,7 +29,7 @@ from jax import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
scipy_version = jtu.parse_version(scipy.version.version)
float_dtypes = jtu.dtypes.floating
real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean

View File

@ -30,7 +30,7 @@ from jax.scipy.special import expit
from jax import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
scipy_version = jtu.parse_version(scipy.version.version)
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]