Merge pull request #6269 from jakevdp:x32-overflow

PiperOrigin-RevId: 365866951
This commit is contained in:
jax authors 2021-03-30 12:12:39 -07:00
commit b48ca49559
6 changed files with 37 additions and 28 deletions

View File

@ -25,6 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
for more information.
* Python integers larger than the maximum `int64` value will now lead to an `OverflowError`
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
`OverflowError` rather than having their value silently truncated.
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).

View File

@ -59,6 +59,11 @@ def PRNGKey(seed: int) -> jnp.ndarray:
key is constructed from a 64-bit seed by effectively bit-casting to a pair
of uint32 values (or from a 32-bit seed by first padding out with zeros).
"""
# Avoid OverflowError in X32 mode by first converting ints to int64.
# This breaks JIT invariance of PRNGKey for large ints, but supports the
# common use-case of instantiating PRNGKey with Python hashes in X32 mode.
if isinstance(seed, int):
seed = np.int64(seed)
seed_arr = jnp.asarray(seed)
if seed_arr.shape:
raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
@ -279,7 +284,7 @@ def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, data)
return _fold_in(key, jnp.uint32(data))
@jit
def _fold_in(key, data):

View File

@ -121,8 +121,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None):
---------------------------------------------------------------------------
OverflowError: Python int 9223372036854775808 too large to convert to int64
"""
dtype = python_scalar_dtypes[typ]
# TODO(jakevdp): use proper overflow for int32.
dtype = canonicalize_dtype(python_scalar_dtypes[typ])
if typ is int and value is not None:
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
raise OverflowError(f"Python int {value} too large to convert to {dtype}")

View File

@ -5031,30 +5031,28 @@ class NamedCallTest(jtu.JaxTestCase):
for jit_type in [None, "python", "cpp"]
if not (jit_type is None and func == 'identity')))
def test_integer_overflow(self, jit_type, func):
def jit(f, **kwargs):
if jit_type is None:
return f
elif jit_type == "python":
return api._python_jit(f, **kwargs)
elif jit_type == "cpp":
return api._cpp_jit(f, **kwargs)
else:
raise ValueError(f"invalid jit_type={jit_type}")
func = jit({
if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65):
self.skipTest("int32 overflow detection not yet implemented in CPP JIT.")
funcdict = {
'identity': lambda x: x,
'asarray': jnp.asarray,
'device_put': api.device_put
}[func])
'device_put': api.device_put,
}
jit = {
'python': api._python_jit,
'cpp': api._cpp_jit,
None: lambda x: x,
}
f = jit[jit_type](funcdict[func])
int64_max = np.iinfo(np.int64).max
int64_min = np.iinfo(np.int64).min
int_dtype = dtypes.canonicalize_dtype(jnp.int_)
int_max = np.iinfo(int_dtype).max
int_min = np.iinfo(int_dtype).min
int_dtype = dtypes.canonicalize_dtype(np.int64)
self.assertEqual(func(int64_max).dtype, int_dtype)
self.assertEqual(func(int64_min).dtype, int_dtype)
self.assertRaises(OverflowError, func, int64_max + 1)
self.assertRaises(OverflowError, func, int64_min - 1)
self.assertEqual(f(int_max).dtype, int_dtype)
self.assertEqual(f(int_min).dtype, int_dtype)
self.assertRaises(OverflowError, f, int_max + 1)
self.assertRaises(OverflowError, f, int_min - 1)
if __name__ == '__main__':

View File

@ -3275,17 +3275,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.array(3, [('a','<i4'),('b','<i4')])
def testArrayFromInteger(self):
# TODO(jakevdp): implement X32 overflow and canonicalize these
int_max = jnp.iinfo(jnp.int64).max
int_min = jnp.iinfo(jnp.int64).min
int_dtype = dtypes.canonicalize_dtype(jnp.int64)
int_max = jnp.iinfo(int_dtype).max
int_min = jnp.iinfo(int_dtype).min
# Values at extremes are converted correctly.
for val in [int_min, 0, int_max]:
self.assertEqual(jnp.array(val).dtype, dtypes.canonicalize_dtype('int64'))
self.assertEqual(jnp.array(val).dtype, int_dtype)
# Out-of-bounds values lead to an OverflowError.
val = int_max + 1
with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to int64"):
with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to {int_dtype.name}"):
jnp.array(val)
# explicit uint64 should work

View File

@ -946,6 +946,9 @@ class LaxRandomTest(jtu.JaxTestCase):
]
))
def test_prng_seeds_and_keys(self, seed, type, jit, key):
if (jit and type is int and not config.x64_enabled and
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
self.skipTest("Expected failure: integer out of range for jit.")
seed = type(seed)
if jit:
actual = api.jit(random.PRNGKey)(seed)
@ -961,6 +964,8 @@ class LaxRandomTest(jtu.JaxTestCase):
def test_prng_jit_invariance(self, seed, type):
if type == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
self.skipTest("Expected failure: Python int too large.")
type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type]
args_maker = lambda: [type(seed)]
self._CompileAndCheck(random.PRNGKey, args_maker)