mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix jnp.unwrap() test failures on GPU.
A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16. In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function. PiperOrigin-RevId: 476494989
This commit is contained in:
parent
d2fcfb6b83
commit
8ee7129874
@ -1073,7 +1073,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for discont in [None, "pi", 2]
|
||||
for period in ["2pi", "pi"]
|
||||
for axis in list(range(-len(shape), len(shape)))))
|
||||
@jtu.skip_on_devices('gpu')
|
||||
def testUnwrap(self, shape, dtype, axis, discont, period):
|
||||
if numpy_version < (1, 21) and period != "2pi":
|
||||
self.skipTest("numpy < 1.21 does not support the period argument to unwrap()")
|
||||
@ -1082,18 +1081,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
discont = special_vals.get(discont, discont)
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
if numpy_version < (1, 21):
|
||||
np_fun = partial(np.unwrap, axis=axis, discont=discont)
|
||||
else:
|
||||
np_fun = partial(np.unwrap, axis=axis, discont=discont, period=period)
|
||||
|
||||
def np_fun(x):
|
||||
dtype = None
|
||||
if x.dtype == dtypes.bfloat16:
|
||||
dtype = x.dtype
|
||||
x = x.astype(np.float32)
|
||||
if numpy_version < (1, 21):
|
||||
out = np.unwrap(x, axis=axis, discont=discont or np.pi)
|
||||
else:
|
||||
out = np.unwrap(x, axis=axis, discont=discont, period=period)
|
||||
return out if dtype is None else out.astype(dtype)
|
||||
|
||||
jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period)
|
||||
if not dtypes.issubdtype(dtype, np.inexact):
|
||||
# This case requires implicit dtype promotion
|
||||
jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
if dtype != jnp.bfloat16: # numpy crashes on bfloat16
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
|
||||
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
|
||||
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user