From e5f4be55641235837d8601bbf54091cdb0cd2ffe Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 24 Oct 2024 13:07:33 +0200 Subject: [PATCH] [shape_poly] Expands support for random.choice `random.choice` uses `np.insert(arr.shape, new_shape)` which attempts to coerce all the values in `new_shape` to constants when `arr.shape` is constant. Replace use of `np.insert` with tuple slicing and concatenation. The case when the sampled axis has non-constant size and `replace=False` is not supported, because `permutation` on arrays with non-constant size is not supported. Adds tests for many combinations of arguments for `random.choice`. Improves a few error messages. --- jax/_src/prng.py | 5 +-- jax/_src/random.py | 14 ++++++-- .../jax2tf/tests/shape_poly_test.py | 2 +- tests/shape_poly_test.py | 36 ++++++++++++++++++- 4 files changed, 50 insertions(+), 7 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7ca7db022..039b0a309 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1067,8 +1067,9 @@ def threefry_2x32(keypair, count): odd_size = count.size % 2 if not isinstance(odd_size, int): - msg = ("jax.random functions have limited support for shape polymorphism. " - "In particular, the product of the known dimensions must be even.") + msg = ("jax.random functions have limited support for shape polymorphism " + "when using threefry. " + f"In particular, the array size ({count.size}) must be even.") raise core.InconclusiveDimensionOperation(msg) if odd_size: diff --git a/jax/_src/random.py b/jax/_src/random.py index 203f72d40..dc9fc18af 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -581,6 +581,10 @@ def _shuffle(key, x, axis) -> Array: # another analysis (where the keys are generated one bit at a time). exponent = 3 # see tjablin@'s analysis for explanation of this parameter uint32max = jnp.iinfo(np.uint32).max + if not core.is_constant_dim(x.size): + raise NotImplementedError( + "shape polymorphism for `permutation` or `shuffle`" + f" for arrays of non-constant size: {x.size}") num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max))) for _ in range(num_rounds): @@ -640,7 +644,9 @@ def choice(key: KeyArrayLike, if n_inputs <= 0: raise ValueError("a must be greater than 0 unless no samples are taken") if not replace and n_draws > n_inputs: - raise ValueError("Cannot take a larger sample than population when 'replace=False'") + raise ValueError( + f"Cannot take a larger sample (size {n_draws}) than " + f"population (size {n_inputs}) when 'replace=False'") if p is None: if replace: @@ -653,7 +659,9 @@ def choice(key: KeyArrayLike, check_arraylike("choice", p) p_arr, = promote_dtypes_inexact(p) if p_arr.shape != (n_inputs,): - raise ValueError("p must be None or match the shape of a") + raise ValueError( + "p must be None or a 1D vector with the same size as a.shape[axis]. " + f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.") if replace: p_cuml = jnp.cumsum(p_arr) r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype)) @@ -665,7 +673,7 @@ def choice(key: KeyArrayLike, result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) return result.reshape(shape if arr.ndim == 0 else - np.insert(np.delete(arr.shape, axis), axis, shape)) + arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]) def normal(key: KeyArrayLike, diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 38af6d9d7..7fdc6854d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2114,7 +2114,7 @@ _POLY_SHAPE_TEST_HARNESSES = [ polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else (None, None)), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 1b213a8b5..ead77e2b5 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2941,6 +2941,40 @@ _POLY_SHAPE_TEST_HARNESSES = [ RandArg((3, 5, 0), _f32)], polymorphic_shapes=[None, "b0, b1, ..."], override_jax_config_flags=override_jax_config_flags), # type: ignore + [ + PolyHarness("random_choice", f"{flags_name}_arr_poly={arr_poly}_shape_poly={shape_poly}_replace={replace}_use_p={use_p}", + lambda key, a, res_shape, use_p: jax.random.choice( + jax.random.wrap_key_data(key), + a, + shape=res_shape.shape, + p=jnp.full((a.shape[1],), 0.1, dtype=_f32) if use_p else None, + axis=1, + replace=replace), + arg_descriptors=[RandArg((key_size,), np.uint32), + RandArg((64, 12, 4), _f32), # sample on axis=1 + RandArg((3, 4), _f32), + StaticArg(use_p)], + # TODO(necula): threefry requires even-sized samples. + polymorphic_shapes=[None, + "_, 2*b1, _" if arr_poly else None, + "b3, b4" if shape_poly else None], + # The array sampled dimension must be larger than res_shape.size + symbolic_constraints=[ + "2*b1 >= 12" if arr_poly else "1 >= 0", + "2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0", + "12 >= b3*b4" if shape_poly else "1 >= 0" + ], + override_jax_config_flags=override_jax_config_flags, + expect_error=( + (NotImplementedError, "permutation") + if arr_poly and not use_p else None)) # type: ignore + # np.insert used in random.choice tries to coerce shape_poly to + # integer arrays, but only when the arr_poly is False. + for arr_poly in [True, False] + for shape_poly in [True, False] + for replace in [True, False] + for use_p in [True, False] + ], PolyHarness("random_split", f"{flags_name}", lambda key, a: jax.random.key_data( jax.random.split(jax.random.wrap_key_data(key), @@ -2971,7 +3005,7 @@ _POLY_SHAPE_TEST_HARNESSES = [ polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else None), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else None), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [