mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #5868 from jakevdp:fix-randint
PiperOrigin-RevId: 366265880
This commit is contained in:
commit
9ca11d4de9
@ -330,6 +330,47 @@ def _promote_args_inexact(fun_name, *args):
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
|
||||
|
||||
def _convert_and_clip_integer(val, dtype):
|
||||
"""
|
||||
Convert integer-typed val to specified integer dtype, clipping to dtype
|
||||
range rather than wrapping.
|
||||
|
||||
Args:
|
||||
val: value to be converted
|
||||
dtype: dtype of output
|
||||
|
||||
Returns:
|
||||
equivalent of val in new dtype
|
||||
|
||||
Examples
|
||||
--------
|
||||
Normal integer type conversion will wrap:
|
||||
|
||||
>>> val = jnp.uint32(0xFFFFFFFF)
|
||||
>>> val.astype('int32')
|
||||
DeviceArray(-1, dtype=int32)
|
||||
|
||||
This function clips to the values representable in the new type:
|
||||
|
||||
>>> _convert_and_clip_integer(val, 'int32')
|
||||
DeviceArray(2147483647, dtype=int32)
|
||||
"""
|
||||
val = val if isinstance(val, ndarray) else asarray(val)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
|
||||
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
|
||||
|
||||
val_dtype = dtypes.canonicalize_dtype(val.dtype)
|
||||
if val_dtype != val.dtype:
|
||||
# TODO(jakevdp): this is a weird corner case; need to figure out how to handle it.
|
||||
# This happens in X32 mode and can either come from a jax value created in another
|
||||
# context, or a Python integer converted to int64.
|
||||
pass
|
||||
min_val = _constant_like(val, _max(iinfo(dtype).min, iinfo(val_dtype).min))
|
||||
max_val = _constant_like(val, _min(iinfo(dtype).max, iinfo(val_dtype).max))
|
||||
return clip(val, min_val, max_val).astype(dtype)
|
||||
|
||||
|
||||
def _constant_like(x, const):
|
||||
return np.array(const, dtype=_dtype(x))
|
||||
|
||||
|
@ -25,7 +25,7 @@ from jax import numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax.core import NamedShape
|
||||
from jax.api import jit, vmap
|
||||
from jax._src.numpy.lax_numpy import _constant_like, asarray
|
||||
from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer, asarray
|
||||
from jax.lib import xla_bridge
|
||||
from jax.lib import xla_client
|
||||
from jax.lib import cuda_prng
|
||||
@ -441,20 +441,28 @@ def randint(key: jnp.ndarray,
|
||||
def _randint(key, shape, minval, maxval, dtype):
|
||||
_check_shape("randint", shape, np.shape(minval), np.shape(maxval))
|
||||
if not jnp.issubdtype(dtype, np.integer):
|
||||
raise TypeError("randint only accepts integer dtypes.")
|
||||
raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
|
||||
|
||||
minval = lax.convert_element_type(minval, dtype)
|
||||
maxval = lax.convert_element_type(maxval, dtype)
|
||||
minval = _asarray(minval)
|
||||
maxval = _asarray(maxval)
|
||||
if not jnp.issubdtype(minval.dtype, np.integer):
|
||||
minval = minval.astype(int)
|
||||
if not jnp.issubdtype(maxval.dtype, np.integer):
|
||||
maxval = maxval.astype(int)
|
||||
|
||||
# Flag where maxval is greater than the maximum value of dtype
|
||||
# in order to handle cases like randint(key, shape, 0, 256, 'uint8')
|
||||
maxval_out_of_range = lax.gt(
|
||||
maxval, _convert_and_clip_integer(jnp.array(jnp.iinfo(dtype).max, dtype), maxval.dtype))
|
||||
|
||||
minval = _convert_and_clip_integer(minval, dtype)
|
||||
maxval = _convert_and_clip_integer(maxval, dtype)
|
||||
minval = lax.broadcast_to_rank(minval, len(shape))
|
||||
maxval = lax.broadcast_to_rank(maxval, len(shape))
|
||||
nbits = jnp.iinfo(dtype).bits
|
||||
|
||||
if nbits not in (8, 16, 32, 64):
|
||||
raise TypeError("randint only accepts 8-, 16-, 32-, or 64-bit dtypes.")
|
||||
|
||||
# if we don't have minval < maxval, just always return minval
|
||||
# https://github.com/google/jax/issues/222
|
||||
maxval = lax.max(lax.add(minval, np.array(1, dtype)), maxval)
|
||||
raise TypeError(f"randint only accepts 8-, 16-, 32-, or 64-bit dtypes, got {dtype}")
|
||||
|
||||
# This algorithm is biased whenever (maxval - minval) is not a power of 2.
|
||||
# We generate double the number of random bits required by the dtype so as to
|
||||
@ -466,6 +474,18 @@ def _randint(key, shape, minval, maxval, dtype):
|
||||
unsigned_dtype = _UINT_DTYPES[nbits]
|
||||
span = lax.convert_element_type(maxval - minval, unsigned_dtype)
|
||||
|
||||
# Ensure that span=1 when maxval <= minval, so minval is always returned;
|
||||
# https://github.com/google/jax/issues/222
|
||||
span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
|
||||
|
||||
# When maxval is out of range, the span has to be one larger.
|
||||
# If span is already the maximum representable value, this will wrap to zero,
|
||||
# causing remainders below to have no effect, which is the correct semantics.
|
||||
span = lax.select(
|
||||
maxval_out_of_range & (maxval > minval),
|
||||
lax.add(span, lax._const(span, 1)),
|
||||
span)
|
||||
|
||||
# To compute a remainder operation on an integer that might have twice as many
|
||||
# bits as we can represent in the native unsigned dtype, we compute a
|
||||
# multiplier equal to 2**nbits % span. To avoid overflow, we use the identity:
|
||||
|
@ -26,6 +26,7 @@ import scipy.stats
|
||||
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import grad
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
@ -983,6 +984,34 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
api.jit(random.split)(key)
|
||||
self.assertEqual(count[0], 1) # 1 for the argument device_put
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
|
||||
for dtype in int_dtypes + uint_dtypes))
|
||||
def test_randint_bounds(self, dtype):
|
||||
min = np.iinfo(dtype).min
|
||||
max = np.iinfo(dtype).max
|
||||
key = random.PRNGKey(1701)
|
||||
shape = (10,)
|
||||
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
|
||||
expected = random.randint(key, shape, min, max, dtype)
|
||||
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
|
||||
else:
|
||||
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)
|
||||
|
||||
def test_randint_out_of_range(self):
|
||||
key = random.PRNGKey(0)
|
||||
|
||||
r = random.randint(key, (10,), 255, 256, np.uint8)
|
||||
self.assertAllClose(r, jnp.full_like(r, 255))
|
||||
|
||||
r = random.randint(key, (1000,), -128, 128, np.int8)
|
||||
self.assertGreater((r == -128).sum(), 0)
|
||||
self.assertGreater((r == 127).sum(), 0)
|
||||
|
||||
r = random.randint(key, (1000,), -1000, 1000, np.uint8)
|
||||
self.assertGreater((r == 0).sum(), 0)
|
||||
self.assertGreater((r == 255).sum(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user