check that overflow is raised on builtin int overflow

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
Matthew Johnson 2023-04-03 15:47:38 -07:00
parent dd7fc98772
commit c6fa6c1557
2 changed files with 67 additions and 39 deletions

View File

@ -9913,19 +9913,22 @@ class NamedCallTest(jtu.JaxTestCase):
out = f(5)
self.assertEqual(out, 5)
@jtu.sample_product(
@parameterized.parameters(
[dict(func=func, jit=jit)
for func in ['trivial', 'identity', 'asarray', 'device_put']
for func in ['identity_trivial', 'identity', 'closure_trivial', 'closure',
'asarray', 'device_put']
for jit in jtu.JIT_IMPLEMENTATION
if not (jit._name == "noop" and func in ('trivial', 'identity'))
if not (jit._name == "noop" and func in ('identity', 'identity_trivial'))
],
)
def test_integer_overflow(self, jit, func):
funcdict = {
'trivial': lambda x: x,
'identity': lambda x: x * 1, # non-trivial
'asarray': jnp.asarray,
'device_put': api.device_put,
'identity_trivial': lambda x: x, # may hit trivial dispatch path
'identity': lambda x: x + 0,
'closure_trivial': lambda x: jax.jit(lambda: x)(),
'closure': lambda x: jax.jit(lambda: x + 0)(),
'asarray': lambda x: jnp.asarray(x), # add lambdas so no cross-test cache
'device_put': lambda x: api.device_put(x),
}
f = jit(funcdict[func])
@ -9934,8 +9937,16 @@ class NamedCallTest(jtu.JaxTestCase):
int_max = np.iinfo(int_dtype).max
int_min = np.iinfo(int_dtype).min
# check before any jit cache entries
self.assertRaises(OverflowError, f, int_max + 1)
self.assertRaises(OverflowError, f, int_min - 1)
self.assertEqual(f(int_max).dtype, int_dtype)
self.assertEqual(f(int_min).dtype, int_dtype)
self.assertAllClose(f(int_max), int_max)
self.assertAllClose(f(int_min), int_min)
# check after any cache entries
self.assertRaises(OverflowError, f, int_max + 1)
self.assertRaises(OverflowError, f, int_min - 1)
if func in ('trivial', 'identity'):

View File

@ -393,41 +393,58 @@ class PrngTest(jtu.JaxTestCase):
self.assertArraysEqual(f2, s2)
self.assertArraysEqual(f3, s3)
@jtu.sample_product([
{"seed": 0, "type": int, "jit": True, "key": [0, 0]},
{"seed": 0, "type": int, "jit": False, "key": [0, 0]},
{"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
{"seed": 1, "type": np.int32, "jit": False, "key": [0, 1]},
{"seed": 2, "type": np.uint32, "jit": True, "key": [0, 2]},
{"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
@parameterized.parameters([
{"seed": 0, "typ": int, "jit": True, "key": [0, 0]},
{"seed": 0, "typ": int, "jit": False, "key": [0, 0]},
{"seed": 1, "typ": np.int32, "jit": True, "key": [0, 1]},
{"seed": 1, "typ": np.int32, "jit": False, "key": [0, 1]},
{"seed": 2, "typ": np.uint32, "jit": True, "key": [0, 2]},
{"seed": 2, "typ": np.uint32, "jit": False, "key": [0, 2]},
{"seed": 3, "typ": np.int64, "jit": True, "key": [0, 3]},
{"seed": 3, "typ": np.int64, "jit": False, "key": [0, 3]},
{"seed": -1, "typ": int, "jit": True, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -1, "typ": int, "jit": False, "key": [4294967295, 4294967295] if config.x64_enabled else [0, 4294967295]},
{"seed": -2, "typ": np.int32, "jit": True, "key": [0, 4294967294]},
{"seed": -2, "typ": np.int32, "jit": False, "key": [0, 4294967294]},
{"seed": -3, "typ": np.int64, "jit": True, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": -3, "typ": np.int64, "jit": False, "key": [4294967295, 4294967293] if config.x64_enabled else [0, 4294967293]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": True, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 100, "typ": int, "jit": False, "key": [0, 2147483747]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": True, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).max + 101, "typ": np.uint32, "jit": False, "key": [0, 2147483748]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": True, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 100, "typ": int, "jit": False, "key": [4294967295, 2147483548] if config.x64_enabled else [0, 2147483548]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": True, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
{"seed": np.iinfo(np.int32).min - 101, "typ": np.int64, "jit": False, "key": [4294967295, 2147483547] if config.x64_enabled else [0, 2147483547]},
])
def test_prng_seeds_and_keys(self, seed, type, jit, key):
if (jit and type is int and not config.x64_enabled and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
self.skipTest("Expected failure: integer out of range for jit.")
seed = type(seed)
def test_prng_seeds_and_keys(self, seed, typ, jit, key):
seed = typ(seed)
if jit:
actual = _prng_key_as_array(jax.jit(random.PRNGKey)(seed))
maker = lambda k: _prng_key_as_array(jax.jit(random.PRNGKey)(k))
else:
actual = _prng_key_as_array(random.PRNGKey(seed))
expected = jnp.array(key, dtype=jnp.uint32)
self.assertArraysEqual(actual, expected)
maker = lambda k: _prng_key_as_array(random.PRNGKey(k))
if (jit and typ is int and not config.x64_enabled and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
# We expect an error to be raised.
# NOTE: we check 'if jit' because some people rely on builtin int seeds
# (e.g. from PRNGKey(hash("altair is best plotting library"))) outside jit
# First check with no cache entry (note lambda above).
with self.assertRaises(OverflowError):
maker(seed)
# Then populate a cache entry.
maker(typ(0)).block_until_ready()
# Then check now that we have a cache entry.
with self.assertRaises(OverflowError):
maker(seed)
else:
# Otherwise we expect no error.
actual = maker(seed)
expected = jnp.array(key, dtype=jnp.uint32)
self.assertArraysEqual(actual, expected)
def test_default_prng_selection(self):
if not config.jax_enable_custom_prng: