Merge pull request #16393 from hawkinsp:winci2

PiperOrigin-RevId: 540246519
This commit is contained in:
jax authors 2023-06-14 06:01:09 -07:00
commit 9ba1f8e002

View File

@ -2311,12 +2311,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
and not config.x64_enabled):
self.skipTest("Only run float64 testcase when float64 is enabled.")
rng = rng_factory(self.rng())
np_fun = lambda x: np.frexp(x)
jnp_fun = lambda x: jnp.frexp(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
def np_frexp(x):
mantissa, exponent = np.frexp(x)
# NumPy is inconsistent between Windows and Linux/Mac on what the
# value of exponent is if the input is infinite. Normalize to the Linux
# behavior.
exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent)
return mantissa, exponent
self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker,
check_dtypes=np.issubdtype(dtype, np.inexact))
self._CompileAndCheck(jnp_fun, args_maker)
self._CompileAndCheck(jnp.frexp, args_maker)
@jtu.sample_product(
[dict(shape=shape, axis1=axis1, axis2=axis2)