mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add test of jnp.arange() corner case
This commit is contained in:
parent
be751d1dd6
commit
fbd9009c54
@ -5607,6 +5607,28 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
|
||||
jax.jit(lambda stop: jnp.arange(0, stop))(3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": str(dtype), "dtype": dtype}
|
||||
for dtype in [None] + float_dtypes))
|
||||
def testArange64Bit(self, dtype):
|
||||
# Test that jnp.arange uses 64-bit arithmetic to define its range, even if the
|
||||
# output has another dtype. The issue here is that if python scalar inputs to
|
||||
# jnp.arange are cast to float32 before the range is computed, it changes the
|
||||
# number of elements output by the range. It's unclear whether this was deliberate
|
||||
# behavior in the initial implementation, but it's behavior that downstream users
|
||||
# have come to rely on.
|
||||
args = (1.2, 4.8, 0.24)
|
||||
|
||||
# Ensure that this test case leads to differing lengths if cast to float32.
|
||||
self.assertLen(np.arange(*args), 15)
|
||||
self.assertLen(np.arange(*map(np.float32, args)), 16)
|
||||
|
||||
jnp_fun = lambda: jnp.arange(*args, dtype=dtype)
|
||||
np_fun = lambda: np.arange(*args, dtype=dtype)
|
||||
args_maker = lambda: []
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testIssue2347(self):
|
||||
# https://github.com/google/jax/issues/2347
|
||||
object_list = List[Tuple[jnp.array, float, float, jnp.array, bool]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user