mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
check that overflow is raised on builtin int overflow
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
parent
dd7fc98772
commit
c6fa6c1557
@ -9913,19 +9913,22 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
out = f(5)
|
||||
self.assertEqual(out, 5)
|
||||
|
||||
@jtu.sample_product(
|
||||
@parameterized.parameters(
|
||||
[dict(func=func, jit=jit)
|
||||
for func in ['trivial', 'identity', 'asarray', 'device_put']
|
||||
for func in ['identity_trivial', 'identity', 'closure_trivial', 'closure',
|
||||
'asarray', 'device_put']
|
||||
for jit in jtu.JIT_IMPLEMENTATION
|
||||
if not (jit._name == "noop" and func in ('trivial', 'identity'))
|
||||
if not (jit._name == "noop" and func in ('identity', 'identity_trivial'))
|
||||
],
|
||||
)
|
||||
def test_integer_overflow(self, jit, func):
|
||||
funcdict = {
|
||||
'trivial': lambda x: x,
|
||||
'identity': lambda x: x * 1, # non-trivial
|
||||
'asarray': jnp.asarray,
|
||||
'device_put': api.device_put,
|
||||
'identity_trivial': lambda x: x, # may hit trivial dispatch path
|
||||
'identity': lambda x: x + 0,
|
||||
'closure_trivial': lambda x: jax.jit(lambda: x)(),
|
||||
'closure': lambda x: jax.jit(lambda: x + 0)(),
|
||||
'asarray': lambda x: jnp.asarray(x), # add lambdas so no cross-test cache
|
||||
'device_put': lambda x: api.device_put(x),
|
||||
}
|
||||
|
||||
f = jit(funcdict[func])
|
||||
@ -9934,8 +9937,16 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
int_max = np.iinfo(int_dtype).max
|
||||
int_min = np.iinfo(int_dtype).min
|
||||
|
||||
# check before any jit cache entries
|
||||
self.assertRaises(OverflowError, f, int_max + 1)
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
|
||||
self.assertEqual(f(int_max).dtype, int_dtype)
|
||||
self.assertEqual(f(int_min).dtype, int_dtype)
|
||||
self.assertAllClose(f(int_max), int_max)
|
||||
self.assertAllClose(f(int_min), int_min)
|
||||
|
||||
# check after any cache entries
|
||||
self.assertRaises(OverflowError, f, int_max + 1)
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
if func in ('trivial', 'identity'):
|
||||
|
@ -393,41 +393,58 @@ class PrngTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(f2, s2)
|
||||
self.assertArraysEqual(f3, s3)
|
||||
|
||||
@jtu.sample_product([
|
||||
{"seed": 0, "type": int, "jit": True, "key": [0, 0]},
|
||||
{"seed": 0, "type": int, "jit": False, "key": [0, 0]},
|
||||
{"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
|
||||
{"seed": 1, "type": np.int32, "jit": False, "key": [0, 1]},
|
||||
{"seed": 2, "type": np.uint32, "jit": True, "key": [0, 2]},
|
||||
{"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
|
||||
{"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
|
||||
{"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
|
||||
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]},
|
||||
{"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]},
|
||||
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
@parameterized.parameters([
|
||||
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
|
||||
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
|
||||
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
|
||||
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
|
||||
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
|
||||
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
|
||||
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
|
||||
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
|
||||
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
|
||||
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
|
||||
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
|
||||
])
|
||||
def test_prng_seeds_and_keys(self, seed, type, jit, key):
|
||||
if (jit and type is int and not config.x64_enabled and
|
||||
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
|
||||
self.skipTest("Expected failure: integer out of range for jit.")
|
||||
seed = type(seed)
|
||||
def test_prng_seeds_and_keys(self, seed, typ, jit, key):
|
||||
seed = typ(seed)
|
||||
if jit:
|
||||
actual = _prng_key_as_array(jax.jit(random.PRNGKey)(seed))
|
||||
maker = lambda k: _prng_key_as_array(jax.jit(random.PRNGKey)(k))
|
||||
else:
|
||||
actual = _prng_key_as_array(random.PRNGKey(seed))
|
||||
expected = jnp.array(key, dtype=jnp.uint32)
|
||||
self.assertArraysEqual(actual, expected)
|
||||
maker = lambda k: _prng_key_as_array(random.PRNGKey(k))
|
||||
if (jit and typ is int and not config.x64_enabled and
|
||||
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
|
||||
# We expect an error to be raised.
|
||||
# NOTE: we check 'if jit' because some people rely on builtin int seeds
|
||||
# (e.g. from PRNGKey(hash("altair is best plotting library"))) outside jit
|
||||
|
||||
# First check with no cache entry (note lambda above).
|
||||
with self.assertRaises(OverflowError):
|
||||
maker(seed)
|
||||
|
||||
# Then populate a cache entry.
|
||||
maker(typ(0)).block_until_ready()
|
||||
|
||||
# Then check now that we have a cache entry.
|
||||
with self.assertRaises(OverflowError):
|
||||
maker(seed)
|
||||
|
||||
else:
|
||||
# Otherwise we expect no error.
|
||||
actual = maker(seed)
|
||||
expected = jnp.array(key, dtype=jnp.uint32)
|
||||
self.assertArraysEqual(actual, expected)
|
||||
|
||||
def test_default_prng_selection(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
|
Loading…
x
Reference in New Issue
Block a user