mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8520 from jakevdp:fix-percentile
PiperOrigin-RevId: 409495454
This commit is contained in:
commit
eeb9bf7a47
@ -6287,8 +6287,7 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
|
||||
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
|
||||
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
|
||||
"'midpoint', or 'nearest'")
|
||||
a = asarray(a, dtype=promote_types(_dtype(a), float32))
|
||||
q = asarray(q, dtype=promote_types(_dtype(q), float32))
|
||||
a, q = _promote_dtypes_inexact(a, q)
|
||||
if axis is None:
|
||||
a = ravel(a)
|
||||
axis = 0
|
||||
@ -6473,7 +6472,8 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
out=None, overwrite_input=False, interpolation="linear",
|
||||
keepdims=False):
|
||||
_check_arraylike("percentile", a, q)
|
||||
q = true_divide(q, float32(100.0))
|
||||
a, q = _promote_dtypes_inexact(a, q)
|
||||
q = true_divide(q, 100.0)
|
||||
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
|
||||
interpolation=interpolation, keepdims=keepdims)
|
||||
|
||||
|
@ -4416,6 +4416,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
|
||||
|
||||
@unittest.skipIf(not config.jax_enable_x64, "test requires X64")
|
||||
@unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision")
|
||||
def testPercentilePrecision(self):
|
||||
# Regression test for https://github.com/google/jax/issues/8513
|
||||
x = jnp.float64([1, 2, 3, 4, 7, 10])
|
||||
self.assertEqual(jnp.percentile(x, 50), 3.5)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
|
Loading…
x
Reference in New Issue
Block a user