Merge pull request #5868 from jakevdp:fix-randint

PiperOrigin-RevId: 366265880
This commit is contained in:
jax authors 2021-04-01 09:25:08 -07:00
commit 9ca11d4de9
3 changed files with 99 additions and 9 deletions

View File

@ -330,6 +330,47 @@ def _promote_args_inexact(fun_name, *args):
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
def _convert_and_clip_integer(val, dtype):
"""
Convert integer-typed val to specified integer dtype, clipping to dtype
range rather than wrapping.
Args:
val: value to be converted
dtype: dtype of output
Returns:
equivalent of val in new dtype
Examples
--------
Normal integer type conversion will wrap:
>>> val = jnp.uint32(0xFFFFFFFF)
>>> val.astype('int32')
DeviceArray(-1, dtype=int32)
This function clips to the values representable in the new type:
>>> _convert_and_clip_integer(val, 'int32')
DeviceArray(2147483647, dtype=int32)
"""
val = val if isinstance(val, ndarray) else asarray(val)
dtype = dtypes.canonicalize_dtype(dtype)
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
val_dtype = dtypes.canonicalize_dtype(val.dtype)
if val_dtype != val.dtype:
# TODO(jakevdp): this is a weird corner case; need to figure out how to handle it.
# This happens in X32 mode and can either come from a jax value created in another
# context, or a Python integer converted to int64.
pass
min_val = _constant_like(val, _max(iinfo(dtype).min, iinfo(val_dtype).min))
max_val = _constant_like(val, _min(iinfo(dtype).max, iinfo(val_dtype).max))
return clip(val, min_val, max_val).astype(dtype)
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))

View File

@ -25,7 +25,7 @@ from jax import numpy as jnp
from jax import dtypes
from jax.core import NamedShape
from jax.api import jit, vmap
from jax._src.numpy.lax_numpy import _constant_like, asarray
from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer, asarray
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng
@ -441,20 +441,28 @@ def randint(key: jnp.ndarray,
def _randint(key, shape, minval, maxval, dtype):
_check_shape("randint", shape, np.shape(minval), np.shape(maxval))
if not jnp.issubdtype(dtype, np.integer):
raise TypeError("randint only accepts integer dtypes.")
raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
minval = _asarray(minval)
maxval = _asarray(maxval)
if not jnp.issubdtype(minval.dtype, np.integer):
minval = minval.astype(int)
if not jnp.issubdtype(maxval.dtype, np.integer):
maxval = maxval.astype(int)
# Flag where maxval is greater than the maximum value of dtype
# in order to handle cases like randint(key, shape, 0, 256, 'uint8')
maxval_out_of_range = lax.gt(
maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))
minval = _convert_and_clip_integer(minval, dtype)
maxval = _convert_and_clip_integer(maxval, dtype)
minval = lax.broadcast_to_rank(minval, len(shape))
maxval = lax.broadcast_to_rank(maxval, len(shape))
nbits = jnp.iinfo(dtype).bits
if nbits not in (8, 16, 32, 64):
raise TypeError("randint only accepts 8-, 16-, 32-, or 64-bit dtypes.")
# if we don't have minval < maxval, just always return minval
# https://github.com/google/jax/issues/222
maxval = lax.max(lax.add(minval, np.array(1, dtype)), maxval)
raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}")
# This algorithm is biased whenever (maxval - minval) is not a power of 2.
# We generate double the number of random bits required by the dtype so as to
@ -466,6 +474,18 @@ def _randint(key, shape, minval, maxval, dtype):
unsigned_dtype = _UINT_DTYPES[nbits]
span = lax.convert_element_type(maxval - minval, unsigned_dtype)
# Ensure that span=1 when maxval <= minval, so minval is always returned;
# https://github.com/google/jax/issues/222
span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
# When maxval is out of range, the span has to be one larger.
# If span is already the maximum representable value, this will wrap to zero,
# causing remainders below to have no effect, which is the correct semantics.
span = lax.select(
maxval_out_of_range & (maxval > minval),
lax.add(span, lax._const(span, 1)),
span)
# To compute a remainder operation on an integer that might have twice as many
# bits as we can represent in the native unsigned dtype, we compute a
# multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:

View File

@ -26,6 +26,7 @@ import scipy.stats
from jax import api
from jax import core
from jax import dtypes
from jax import grad
from jax import lax
from jax import numpy as jnp
@ -983,6 +984,34 @@ class LaxRandomTest(jtu.JaxTestCase):
api.jit(random.split)(key)
self.assertEqual(count[0], 1) # 1 for the argument device_put
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
key = random.PRNGKey(1701)
shape = (10,)
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
expected = random.randint(key, shape, min, max, dtype)
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
else:
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
def test_randint_out_of_range(self):
key = random.PRNGKey(0)
r = random.randint(key, (10,), 255, 256, np.uint8)
self.assertAllClose(r, jnp.full_like(r, 255))
r = random.randint(key, (1000,), -128, 128, np.int8)
self.assertGreater((r == -128).sum(), 0)
self.assertGreater((r == 127).sum(), 0)
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
self.assertGreater((r == 0).sum(), 0)
self.assertGreater((r == 255).sum(), 0)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())