mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Tracer: add missing __round__ and __reversed__ methods
This commit is contained in:
parent
441f400358
commit
74698048f3
@ -550,6 +550,9 @@ class Tracer:
|
||||
def __iter__(self):
|
||||
return iter(self.aval._iter(self))
|
||||
|
||||
def __reversed__(self):
|
||||
return iter(self[::-1])
|
||||
|
||||
def __len__(self):
|
||||
return self.aval._len(self)
|
||||
|
||||
@ -617,6 +620,7 @@ class Tracer:
|
||||
def __complex__(self): return self.aval._complex(self)
|
||||
def __copy__(self): return self.aval._copy(self)
|
||||
def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo)
|
||||
def __round__(self, ndigits=None): return self.aval._round(self, ndigits)
|
||||
|
||||
# raises a useful error on attempts to pickle a Tracer.
|
||||
def __reduce__(self):
|
||||
|
@ -1731,21 +1731,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
def testOperatorRound(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_jit={jit}", "jit": jit} for jit in [True, False])
|
||||
def testOperatorRound(self, jit):
|
||||
jround = jax.jit(round, static_argnums=1) if jit else round
|
||||
self.assertAllClose(round(np.float32(7.532), 1),
|
||||
round(jnp.float32(7.5), 1))
|
||||
jround(jnp.float32(7.5), 1))
|
||||
self.assertAllClose(round(np.float32(1.234), 2),
|
||||
round(jnp.float32(1.234), 2))
|
||||
jround(jnp.float32(1.234), 2))
|
||||
self.assertAllClose(round(np.float32(1.234)),
|
||||
round(jnp.float32(1.234)), check_dtypes=False)
|
||||
jround(jnp.float32(1.234)), check_dtypes=False)
|
||||
self.assertAllClose(round(np.float32(7.532), 1),
|
||||
round(jnp.array(7.5, jnp.float32), 1))
|
||||
jround(jnp.array(7.5, jnp.float32), 1))
|
||||
self.assertAllClose(round(np.float32(1.234), 2),
|
||||
round(jnp.array(1.234, jnp.float32), 2))
|
||||
jround(jnp.array(1.234, jnp.float32), 2))
|
||||
self.assertAllClose(round(np.float32(1.234)),
|
||||
round(jnp.array(1.234, jnp.float32)),
|
||||
jround(jnp.array(1.234, jnp.float32)),
|
||||
check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_shape={shape}", "shape": shape}
|
||||
for shape in [(5,), (5, 2)])
|
||||
def testOperatorReversed(self, shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, 'float32')]
|
||||
np_fun = lambda x: np.array(list(reversed(x)))
|
||||
jnp_fun = lambda x: jnp.array(list(reversed(x)))
|
||||
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), mode, pad_width,
|
||||
|
Loading…
x
Reference in New Issue
Block a user