mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1095 from hawkinsp/fixes
Fix test failures due to Numpy 1.17.
This commit is contained in:
commit
24d4aaf3e2
@ -589,7 +589,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
|
||||
lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
|
||||
# TODO(phawkins): the promotion behavior changed in Numpy 1.17.
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -1495,9 +1496,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
"rng": jtu.rand_default(), "shapes": shapes, "dtypes": dtypes}
|
||||
for shapes, dtypes in (
|
||||
((), ()),
|
||||
(((7,),), (onp.float32,)),
|
||||
(((3,), (4,)), (onp.float32, onp.int32)),
|
||||
(((3,), (0,), (4,)), (onp.int32, onp.float32, onp.int32)),
|
||||
(((7,),), (onp.int32,)),
|
||||
(((3,), (4,)), (onp.int32, onp.int32)),
|
||||
(((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)),
|
||||
)))
|
||||
def testIx_(self, rng, shapes, dtypes):
|
||||
args_maker = lambda: [rng(shape, dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user