Merge pull request #24500 from gnecula:poly_choice

PiperOrigin-RevId: 689792194
This commit is contained in:
jax authors 2024-10-25 08:10:52 -07:00
commit 8c6164a492
4 changed files with 50 additions and 7 deletions

View File

@ -1067,8 +1067,9 @@ def threefry_2x32(keypair, count):
odd_size = count.size % 2
if not isinstance(odd_size, int):
msg = ("jax.random functions have limited support for shape polymorphism. "
"In particular, the product of the known dimensions must be even.")
msg = ("jax.random functions have limited support for shape polymorphism "
"when using threefry. "
f"In particular, the array size ({count.size}) must be even.")
raise core.InconclusiveDimensionOperation(msg)
if odd_size:

View File

@ -581,6 +581,10 @@ def _shuffle(key, x, axis) -> Array:
# another analysis (where the keys are generated one bit at a time).
exponent = 3 # see tjablin@'s analysis for explanation of this parameter
uint32max = jnp.iinfo(np.uint32).max
if not core.is_constant_dim(x.size):
raise NotImplementedError(
"shape polymorphism for `permutation` or `shuffle`"
f" for arrays of non-constant size: {x.size}")
num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))
for _ in range(num_rounds):
@ -640,7 +644,9 @@ def choice(key: KeyArrayLike,
if n_inputs <= 0:
raise ValueError("a must be greater than 0 unless no samples are taken")
if not replace and n_draws > n_inputs:
raise ValueError("Cannot take a larger sample than population when 'replace=False'")
raise ValueError(
f"Cannot take a larger sample (size {n_draws}) than "
f"population (size {n_inputs}) when 'replace=False'")
if p is None:
if replace:
@ -653,7 +659,9 @@ def choice(key: KeyArrayLike,
check_arraylike("choice", p)
p_arr, = promote_dtypes_inexact(p)
if p_arr.shape != (n_inputs,):
raise ValueError("p must be None or match the shape of a")
raise ValueError(
"p must be None or a 1D vector with the same size as a.shape[axis]. "
f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.")
if replace:
p_cuml = jnp.cumsum(p_arr)
r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
@ -665,7 +673,7 @@ def choice(key: KeyArrayLike,
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
return result.reshape(shape if arr.ndim == 0 else
np.insert(np.delete(arr.shape, axis), axis, shape))
arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:])
def normal(key: KeyArrayLike,

View File

@ -2089,7 +2089,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [

View File

@ -2941,6 +2941,40 @@ _POLY_SHAPE_TEST_HARNESSES = [
RandArg((3, 5, 0), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
override_jax_config_flags=override_jax_config_flags), # type: ignore
[
PolyHarness("random_choice", f"{flags_name}_arr_poly={arr_poly}_shape_poly={shape_poly}_replace={replace}_use_p={use_p}",
lambda key, a, res_shape, use_p: jax.random.choice(
jax.random.wrap_key_data(key),
a,
shape=res_shape.shape,
p=jnp.full((a.shape[1],), 0.1, dtype=_f32) if use_p else None,
axis=1,
replace=replace),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((64, 12, 4), _f32), # sample on axis=1
RandArg((3, 4), _f32),
StaticArg(use_p)],
# TODO(necula): threefry requires even-sized samples.
polymorphic_shapes=[None,
"_, 2*b1, _" if arr_poly else None,
"b3, b4" if shape_poly else None],
# The array sampled dimension must be at least as large as res_shape.size
symbolic_constraints=[
"2*b1 >= 12" if arr_poly else "1 >= 0",
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"12 >= b3*b4" if shape_poly else "1 >= 0"
],
override_jax_config_flags=override_jax_config_flags,
expect_error=(
(NotImplementedError, "permutation")
if arr_poly and not use_p else None)) # type: ignore
# np.insert used in random.choice tries to coerce shape_poly to
# integer arrays, but only when arr_poly is False.
for arr_poly in [True, False]
for shape_poly in [True, False]
for replace in [True, False]
for use_p in [True, False]
],
PolyHarness("random_split", f"{flags_name}",
lambda key, a: jax.random.key_data(
jax.random.split(jax.random.wrap_key_data(key),
@ -2971,7 +3005,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else None),
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [