diff --git a/CHANGELOG.md b/CHANGELOG.md index c98d87b93..35ca60329 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list * `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')` * `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')` * `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')` +* Changes: + * {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken + across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy + 1.11. # jaxlib 0.4.17 diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index db9533c8f..e83e7a8cb 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -73,7 +73,8 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]: """Helper function to return mode and count of a given array.""" if x.size == 0: - return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)) + return (jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), + jnp.array(0, dtype=dtypes.canonicalize_dtype(jnp.float_))) else: vals, counts = jnp.unique(x, return_counts=True, size=x.size) return vals[jnp.argmax(counts)], counts.max() diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index efc9062f5..2c8868d37 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1481,13 +1481,20 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None): """Wrapper to manage the shape discrepancies between scipy and jax""" - if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True: - if axis == None: - output_shape = tuple(1 for _ in a.shape) + if scipy_version < (1, 11, 0) and a.size == 0: + if keepdims: + if axis == None: + output_shape = tuple(1 for _ in a.shape) + else: + output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape)) else: - output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape)) - return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)), - np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_))) + if axis == None: + output_shape = () + else: + output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis) + t = dtypes.canonicalize_dtype(jax.numpy.float_) + return (np.full(output_shape, np.nan, dtype=t), + np.zeros(output_shape, dtype=t)) if scipy_version < (1, 9, 0): result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)