mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix test failures under SciPy 1.11 for scipy.stats.mode.
This commit is contained in:
parent
e84be656fc
commit
2fd6df45e4
@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
|
||||
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
||||
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
||||
* Changes:
|
||||
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
|
||||
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
|
||||
1.11.
|
||||
|
||||
# jaxlib 0.4.17
|
||||
|
||||
|
@ -73,7 +73,8 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
|
||||
def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
|
||||
"""Helper function to return mode and count of a given array."""
|
||||
if x.size == 0:
|
||||
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
|
||||
return (jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)),
|
||||
jnp.array(0, dtype=dtypes.canonicalize_dtype(jnp.float_)))
|
||||
else:
|
||||
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
|
||||
return vals[jnp.argmax(counts)], counts.max()
|
||||
|
@ -1481,13 +1481,20 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
|
||||
"""Wrapper to manage the shape discrepancies between scipy and jax"""
|
||||
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
|
||||
if axis == None:
|
||||
output_shape = tuple(1 for _ in a.shape)
|
||||
if scipy_version < (1, 11, 0) and a.size == 0:
|
||||
if keepdims:
|
||||
if axis == None:
|
||||
output_shape = tuple(1 for _ in a.shape)
|
||||
else:
|
||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
|
||||
else:
|
||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
|
||||
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
|
||||
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
|
||||
if axis == None:
|
||||
output_shape = ()
|
||||
else:
|
||||
output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis)
|
||||
t = dtypes.canonicalize_dtype(jax.numpy.float_)
|
||||
return (np.full(output_shape, np.nan, dtype=t),
|
||||
np.zeros(output_shape, dtype=t))
|
||||
|
||||
if scipy_version < (1, 9, 0):
|
||||
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
|
||||
|
Loading…
x
Reference in New Issue
Block a user