Add test of jnp.arange() corner case

This commit is contained in:
Jake VanderPlas 2021-11-15 13:33:51 -08:00
parent be751d1dd6
commit fbd9009c54

View File

@ -5607,6 +5607,28 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
jax.jit(lambda stop: jnp.arange(0, stop))(3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": str(dtype), "dtype": dtype}
for dtype in [None] + float_dtypes))
def testArange64Bit(self, dtype):
# Test that jnp.arange uses 64-bit arithmetic to define its range, even if the
# output has another dtype. The issue here is that if python scalar inputs to
# jnp.arange are cast to float32 before the range is computed, it changes the
# number of elements output by the range. It's unclear whether this was deliberate
# behavior in the initial implementation, but it's behavior that downstream users
# have come to rely on.
args = (1.2, 4.8, 0.24)
# Ensure that this test case leads to differing lengths if cast to float32.
self.assertLen(np.arange(*args), 15)
self.assertLen(np.arange(*map(np.float32, args)), 16)
jnp_fun = lambda: jnp.arange(*args, dtype=dtype)
np_fun = lambda: np.arange(*args, dtype=dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testIssue2347(self):
# https://github.com/google/jax/issues/2347
object_list = List[Tuple[jnp.array, float, float, jnp.array, bool]]