mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6269 from jakevdp:x32-overflow
PiperOrigin-RevId: 365866951
This commit is contained in:
commit
b48ca49559
@ -25,6 +25,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
for more information.
|
||||
* Python integers larger than the maximum `int64` value will now lead to an overflow
|
||||
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
|
||||
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
|
||||
`OverflowError` rather than having their value silently truncated.
|
||||
* Bug fixes:
|
||||
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
|
||||
|
||||
|
@ -59,6 +59,11 @@ def PRNGKey(seed: int) -> jnp.ndarray:
|
||||
key is constructed from a 64-bit seed by effectively bit-casting to a pair
|
||||
of uint32 values (or from a 32-bit seed by first padding out with zeros).
|
||||
"""
|
||||
# Avoid overflowerror in X32 mode by first converting ints to int64.
|
||||
# This breaks JIT invariance of PRNGKey for large ints, but supports the
|
||||
# common use-case of instantiating PRNGKey with Python hashes in X32 mode.
|
||||
if isinstance(seed, int):
|
||||
seed = np.int64(seed)
|
||||
seed_arr = jnp.asarray(seed)
|
||||
if seed_arr.shape:
|
||||
raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
|
||||
@ -279,7 +284,7 @@ def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
A new PRNGKey that is a deterministic function of the inputs and is
|
||||
statistically safe for producing a stream of new pseudo-random values.
|
||||
"""
|
||||
return _fold_in(key, data)
|
||||
return _fold_in(key, jnp.uint32(data))
|
||||
|
||||
@jit
|
||||
def _fold_in(key, data):
|
||||
|
@ -121,8 +121,7 @@ def _scalar_type_to_dtype(typ: type, value: Any = None):
|
||||
---------------------------------------------------------------------------
|
||||
OverflowError: Python int 9223372036854775808 too large to convert to int64
|
||||
"""
|
||||
dtype = python_scalar_dtypes[typ]
|
||||
# TODO(jakevdp): use proper overflow for int32.
|
||||
dtype = canonicalize_dtype(python_scalar_dtypes[typ])
|
||||
if typ is int and value is not None:
|
||||
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
|
||||
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
|
||||
|
@ -5031,30 +5031,28 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
for jit_type in [None, "python", "cpp"]
|
||||
if not (jit_type is None and func == 'identity')))
|
||||
def test_integer_overflow(self, jit_type, func):
|
||||
def jit(f, **kwargs):
|
||||
if jit_type is None:
|
||||
return f
|
||||
elif jit_type == "python":
|
||||
return api._python_jit(f, **kwargs)
|
||||
elif jit_type == "cpp":
|
||||
return api._cpp_jit(f, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"invalid jit_type={jit_type}")
|
||||
func = jit({
|
||||
if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65):
|
||||
self.skipTest("int32 overflow detection not yet implemented in CPP JIT.")
|
||||
funcdict = {
|
||||
'identity': lambda x: x,
|
||||
'asarray': jnp.asarray,
|
||||
'device_put': api.device_put
|
||||
}[func])
|
||||
'device_put': api.device_put,
|
||||
}
|
||||
jit = {
|
||||
'python': api._python_jit,
|
||||
'cpp': api._cpp_jit,
|
||||
None: lambda x: x,
|
||||
}
|
||||
f = jit[jit_type](funcdict[func])
|
||||
|
||||
int64_max = np.iinfo(np.int64).max
|
||||
int64_min = np.iinfo(np.int64).min
|
||||
int_dtype = dtypes.canonicalize_dtype(jnp.int_)
|
||||
int_max = np.iinfo(int_dtype).max
|
||||
int_min = np.iinfo(int_dtype).min
|
||||
|
||||
int_dtype = dtypes.canonicalize_dtype(np.int64)
|
||||
|
||||
self.assertEqual(func(int64_max).dtype, int_dtype)
|
||||
self.assertEqual(func(int64_min).dtype, int_dtype)
|
||||
self.assertRaises(OverflowError, func, int64_max + 1)
|
||||
self.assertRaises(OverflowError, func, int64_min - 1)
|
||||
self.assertEqual(f(int_max).dtype, int_dtype)
|
||||
self.assertEqual(f(int_min).dtype, int_dtype)
|
||||
self.assertRaises(OverflowError, f, int_max + 1)
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -3275,17 +3275,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.array(3, [('a','<i4'),('b','<i4')])
|
||||
|
||||
def testArrayFromInteger(self):
|
||||
# TODO(jakevdp): implement X32 overflow and canonicalize these
|
||||
int_max = jnp.iinfo(jnp.int64).max
|
||||
int_min = jnp.iinfo(jnp.int64).min
|
||||
int_dtype = dtypes.canonicalize_dtype(jnp.int64)
|
||||
int_max = jnp.iinfo(int_dtype).max
|
||||
int_min = jnp.iinfo(int_dtype).min
|
||||
|
||||
# Values at extremes are converted correctly.
|
||||
for val in [int_min, 0, int_max]:
|
||||
self.assertEqual(jnp.array(val).dtype, dtypes.canonicalize_dtype('int64'))
|
||||
self.assertEqual(jnp.array(val).dtype, int_dtype)
|
||||
|
||||
# out of bounds leads to an OverflowError
|
||||
val = int_max + 1
|
||||
with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to int64"):
|
||||
with self.assertRaisesRegex(OverflowError, f"Python int {val} too large to convert to {int_dtype.name}"):
|
||||
jnp.array(val)
|
||||
|
||||
# explicit uint64 should work
|
||||
|
@ -946,6 +946,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
]
|
||||
))
|
||||
def test_prng_seeds_and_keys(self, seed, type, jit, key):
|
||||
if (jit and type is int and not config.x64_enabled and
|
||||
(seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
|
||||
self.skipTest("Expected failure: integer out of range for jit.")
|
||||
seed = type(seed)
|
||||
if jit:
|
||||
actual = api.jit(random.PRNGKey)(seed)
|
||||
@ -961,6 +964,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def test_prng_jit_invariance(self, seed, type):
|
||||
if type == "int" and seed == (1 << 64) - 1:
|
||||
self.skipTest("Expected failure: Python int too large.")
|
||||
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
|
||||
self.skipTest("Expected failure: Python int too large.")
|
||||
type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type]
|
||||
args_maker = lambda: [type(seed)]
|
||||
self._CompileAndCheck(random.PRNGKey, args_maker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user