mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix test failures under SciPy 1.11 for scipy.stats.mode.
This commit is contained in:
parent
e84be656fc
commit
2fd6df45e4
@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
|
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
|
||||||
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
||||||
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
||||||
|
* Changes:
|
||||||
|
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
|
||||||
|
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
|
||||||
|
1.11.
|
||||||
|
|
||||||
# jaxlib 0.4.17
|
# jaxlib 0.4.17
|
||||||
|
|
||||||
|
@ -73,7 +73,8 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
|
|||||||
def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
|
def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
|
||||||
"""Helper function to return mode and count of a given array."""
|
"""Helper function to return mode and count of a given array."""
|
||||||
if x.size == 0:
|
if x.size == 0:
|
||||||
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
|
return (jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)),
|
||||||
|
jnp.array(0, dtype=dtypes.canonicalize_dtype(jnp.float_)))
|
||||||
else:
|
else:
|
||||||
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
|
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
|
||||||
return vals[jnp.argmax(counts)], counts.max()
|
return vals[jnp.argmax(counts)], counts.max()
|
||||||
|
@ -1481,13 +1481,20 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
|
|
||||||
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
|
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
|
||||||
"""Wrapper to manage the shape discrepancies between scipy and jax"""
|
"""Wrapper to manage the shape discrepancies between scipy and jax"""
|
||||||
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
|
if scipy_version < (1, 11, 0) and a.size == 0:
|
||||||
if axis == None:
|
if keepdims:
|
||||||
output_shape = tuple(1 for _ in a.shape)
|
if axis == None:
|
||||||
|
output_shape = tuple(1 for _ in a.shape)
|
||||||
|
else:
|
||||||
|
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
|
||||||
else:
|
else:
|
||||||
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
|
if axis == None:
|
||||||
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
|
output_shape = ()
|
||||||
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
|
else:
|
||||||
|
output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis)
|
||||||
|
t = dtypes.canonicalize_dtype(jax.numpy.float_)
|
||||||
|
return (np.full(output_shape, np.nan, dtype=t),
|
||||||
|
np.zeros(output_shape, dtype=t))
|
||||||
|
|
||||||
if scipy_version < (1, 9, 0):
|
if scipy_version < (1, 9, 0):
|
||||||
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
|
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user