mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #24500 from gnecula:poly_choice
PiperOrigin-RevId: 689792194
This commit is contained in:
commit
8c6164a492
@ -1067,8 +1067,9 @@ def threefry_2x32(keypair, count):
|
||||
|
||||
odd_size = count.size % 2
|
||||
if not isinstance(odd_size, int):
|
||||
msg = ("jax.random functions have limited support for shape polymorphism. "
|
||||
"In particular, the product of the known dimensions must be even.")
|
||||
msg = ("jax.random functions have limited support for shape polymorphism "
|
||||
"when using threefry. "
|
||||
f"In particular, the array size ({count.size}) must be even.")
|
||||
raise core.InconclusiveDimensionOperation(msg)
|
||||
|
||||
if odd_size:
|
||||
|
@ -581,6 +581,10 @@ def _shuffle(key, x, axis) -> Array:
|
||||
# another analysis (where the keys are generated one bit at a time).
|
||||
exponent = 3 # see tjablin@'s analysis for explanation of this parameter
|
||||
uint32max = jnp.iinfo(np.uint32).max
|
||||
if not core.is_constant_dim(x.size):
|
||||
raise NotImplementedError(
|
||||
"shape polymorphism for `permutation` or `shuffle`"
|
||||
f" for arrays of non-constant size: {x.size}")
|
||||
num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))
|
||||
|
||||
for _ in range(num_rounds):
|
||||
@ -640,7 +644,9 @@ def choice(key: KeyArrayLike,
|
||||
if n_inputs <= 0:
|
||||
raise ValueError("a must be greater than 0 unless no samples are taken")
|
||||
if not replace and n_draws > n_inputs:
|
||||
raise ValueError("Cannot take a larger sample than population when 'replace=False'")
|
||||
raise ValueError(
|
||||
f"Cannot take a larger sample (size {n_draws}) than "
|
||||
f"population (size {n_inputs}) when 'replace=False'")
|
||||
|
||||
if p is None:
|
||||
if replace:
|
||||
@ -653,7 +659,9 @@ def choice(key: KeyArrayLike,
|
||||
check_arraylike("choice", p)
|
||||
p_arr, = promote_dtypes_inexact(p)
|
||||
if p_arr.shape != (n_inputs,):
|
||||
raise ValueError("p must be None or match the shape of a")
|
||||
raise ValueError(
|
||||
"p must be None or a 1D vector with the same size as a.shape[axis]. "
|
||||
f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.")
|
||||
if replace:
|
||||
p_cuml = jnp.cumsum(p_arr)
|
||||
r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype))
|
||||
@ -665,7 +673,7 @@ def choice(key: KeyArrayLike,
|
||||
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
|
||||
|
||||
return result.reshape(shape if arr.ndim == 0 else
|
||||
np.insert(np.delete(arr.shape, axis), axis, shape))
|
||||
arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:])
|
||||
|
||||
|
||||
def normal(key: KeyArrayLike,
|
||||
|
@ -2089,7 +2089,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
expect_error=(
|
||||
(core.InconclusiveDimensionOperation,
|
||||
"the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
|
||||
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
|
||||
override_jax_config_flags=override_jax_config_flags) # type: ignore
|
||||
]
|
||||
for key_size, flags_name, override_jax_config_flags in [
|
||||
|
@ -2941,6 +2941,40 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
RandArg((3, 5, 0), _f32)],
|
||||
polymorphic_shapes=[None, "b0, b1, ..."],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
[
|
||||
PolyHarness("random_choice", f"{flags_name}_arr_poly={arr_poly}_shape_poly={shape_poly}_replace={replace}_use_p={use_p}",
|
||||
lambda key, a, res_shape, use_p: jax.random.choice(
|
||||
jax.random.wrap_key_data(key),
|
||||
a,
|
||||
shape=res_shape.shape,
|
||||
p=jnp.full((a.shape[1],), 0.1, dtype=_f32) if use_p else None,
|
||||
axis=1,
|
||||
replace=replace),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((64, 12, 4), _f32), # sample on axis=1
|
||||
RandArg((3, 4), _f32),
|
||||
StaticArg(use_p)],
|
||||
# TODO(necula): threefry requires even-sized samples.
|
||||
polymorphic_shapes=[None,
|
||||
"_, 2*b1, _" if arr_poly else None,
|
||||
"b3, b4" if shape_poly else None],
|
||||
# The array sampled dimension must be larger than res_shape.size
|
||||
symbolic_constraints=[
|
||||
"2*b1 >= 12" if arr_poly else "1 >= 0",
|
||||
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
|
||||
"12 >= b3*b4" if shape_poly else "1 >= 0"
|
||||
],
|
||||
override_jax_config_flags=override_jax_config_flags,
|
||||
expect_error=(
|
||||
(NotImplementedError, "permutation")
|
||||
if arr_poly and not use_p else None)) # type: ignore
|
||||
# np.insert used in random.choice tries to coerce shape_poly to
|
||||
# integer arrays, but only when the arr_poly is False.
|
||||
for arr_poly in [True, False]
|
||||
for shape_poly in [True, False]
|
||||
for replace in [True, False]
|
||||
for use_p in [True, False]
|
||||
],
|
||||
PolyHarness("random_split", f"{flags_name}",
|
||||
lambda key, a: jax.random.key_data(
|
||||
jax.random.split(jax.random.wrap_key_data(key),
|
||||
@ -2971,7 +3005,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
expect_error=(
|
||||
(core.InconclusiveDimensionOperation,
|
||||
"the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else None),
|
||||
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
|
||||
override_jax_config_flags=override_jax_config_flags) # type: ignore
|
||||
]
|
||||
for key_size, flags_name, override_jax_config_flags in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user