mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16393 from hawkinsp:winci2
PiperOrigin-RevId: 540246519
This commit is contained in:
commit
9ba1f8e002
@ -2311,12 +2311,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
and not config.x64_enabled):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
rng = rng_factory(self.rng())
|
||||
np_fun = lambda x: np.frexp(x)
|
||||
jnp_fun = lambda x: jnp.frexp(x)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
def np_frexp(x):
|
||||
mantissa, exponent = np.frexp(x)
|
||||
# NumPy is inconsistent between Windows and Linux/Mac on what the
|
||||
# value of exponent is if the input is infinite. Normalize to the Linux
|
||||
# behavior.
|
||||
exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent)
|
||||
return mantissa, exponent
|
||||
self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker,
|
||||
check_dtypes=np.issubdtype(dtype, np.inexact))
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp.frexp, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis1=axis1, axis2=axis2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user