Tracer: add missing __round__ and __reversed__ methods

This commit is contained in:
Jake VanderPlas 2022-09-20 09:09:23 -07:00
parent 441f400358
commit 74698048f3
2 changed files with 26 additions and 7 deletions

View File

@ -550,6 +550,9 @@ class Tracer:
def __iter__(self):
return iter(self.aval._iter(self))
def __reversed__(self):
return iter(self[::-1])
def __len__(self):
return self.aval._len(self)
@ -617,6 +620,7 @@ class Tracer:
def __complex__(self): return self.aval._complex(self)
def __copy__(self): return self.aval._copy(self)
def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo)
def __round__(self, ndigits=None): return self.aval._round(self, ndigits)
# raises a useful error on attempts to pickle a Tracer.
def __reduce__(self):

View File

@ -1731,21 +1731,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
atol=tol, rtol=tol)
def testOperatorRound(self):
@parameterized.named_parameters(
{"testcase_name": f"_jit={jit}", "jit": jit} for jit in [True, False])
def testOperatorRound(self, jit):
jround = jax.jit(round, static_argnums=1) if jit else round
self.assertAllClose(round(np.float32(7.532), 1),
round(jnp.float32(7.5), 1))
jround(jnp.float32(7.5), 1))
self.assertAllClose(round(np.float32(1.234), 2),
round(jnp.float32(1.234), 2))
jround(jnp.float32(1.234), 2))
self.assertAllClose(round(np.float32(1.234)),
round(jnp.float32(1.234)), check_dtypes=False)
jround(jnp.float32(1.234)), check_dtypes=False)
self.assertAllClose(round(np.float32(7.532), 1),
round(jnp.array(7.5, jnp.float32), 1))
jround(jnp.array(7.5, jnp.float32), 1))
self.assertAllClose(round(np.float32(1.234), 2),
round(jnp.array(1.234, jnp.float32), 2))
jround(jnp.array(1.234, jnp.float32), 2))
self.assertAllClose(round(np.float32(1.234)),
round(jnp.array(1.234, jnp.float32)),
jround(jnp.array(1.234, jnp.float32)),
check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": f"_shape={shape}", "shape": shape}
for shape in [(5,), (5, 2)])
def testOperatorReversed(self, shape):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, 'float32')]
np_fun = lambda x: np.array(list(reversed(x)))
jnp_fun = lambda x: jnp.array(list(reversed(x)))
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode, pad_width,