From 4b81680b629c1c7633ab37e41062a11da8b73094 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 12 Jun 2024 14:36:31 -0700 Subject: [PATCH] [Pallas] Allow keys as input to Pallas kernels. PiperOrigin-RevId: 642740833 --- jax/_src/pallas/mosaic/lowering.py | 97 ++++++++++++++++++- .../pallas/mosaic/pallas_call_registration.py | 8 ++ jax/_src/pallas/mosaic/random.py | 58 ++++------- tests/pallas/tpu/pallas_random_test.py | 55 ++++++++--- 4 files changed, 164 insertions(+), 54 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9c80cc08e..dd558fcc7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -28,6 +28,7 @@ from jax import tree_util from jax._src import ad_util from jax._src import custom_derivatives from jax._src import debugging +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit @@ -150,6 +151,15 @@ def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None raise ValueError(f"Cannot allocate {aval.sem_type}.") memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) + if dtypes.issubdtype(aval.dtype, dtypes.prng_key): + shape = aval.dtype._impl.key_shape + if memory_space is None: + memory_space = TPUMemorySpace.SMEM + if memory_space != TPUMemorySpace.SMEM: + raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") + memspace = _memory_space_to_tpu_memspace(memory_space) + return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), + memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape @@ -626,7 +636,8 @@ def jaxpr_subcomp( return atom.val if isinstance(atom, jax_core.Literal) else env[atom] def write_env(var: jax_core.Var, val): - assert isinstance(val, ir.Value), type(val) + is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle)) + assert is_valid_type, type(val) env[var] = val for invar, bs in zip(jaxpr.invars, ctx.block_shapes): @@ -703,6 +714,8 @@ def jaxpr_subcomp( def _ensure_mlir_value(val, aval): if isinstance(val, ir.Value): return val + if isinstance(val, KeyScalarBundle): + return val elif isinstance(val, (np.generic, np.ndarray, int, float)): return ir_constant(val, _dtype_to_ir_type(aval.dtype)) else: @@ -898,6 +911,21 @@ def _index_ref(ref, ref_aval, ref_block_shape, indexers): ref_block_shape) return ref, ref_block_shape +@dataclasses.dataclass(frozen=True) +class KeyScalarBundle: + """A container class for PRNG key data. + + We pass around keys as a KeyScalarBundle in the lowering pass rather than + as a vector, since we want the key data to live in scalar registers rather + than vector registers. This special dataclass exists so we can return + multiple scalar values from load_op, because the load_op primitive does + not allow multiple results. + + Attributes: + scalars: A list of OpResults representing scalar key data during the + lowering pass. 
+ """ + scalars: list[ir.OpResult] def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, indexers, mask, _ = args_tree.unflatten(args_flat) @@ -916,6 +944,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" ref_aval, *_ = ctx.avals_in (aval_out,) = ctx.avals_out + if isinstance(aval_out.dtype, prng.KeyTy): + if not is_smem_load: + raise ValueError("PRNG keys must be loaded from SMEM. Did you set " + "the memory space to TPUMemorySpace.SMEM in the " + "BlockSpec for the PRNG key input?") + return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) if not is_smem_load and not ref_block_shape: raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") @@ -947,6 +981,37 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): _dtype_to_ir_type(aval_out.dtype)) return vector.ShapeCastOp(vec_type, load_val).result +def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: + """Lowering rule for loading PRNG keys from SMEM. + + PRNG key loads are currently lowered as a list of scalar loads from SMEM, + rather than a single vector load. + We store these scalars in a bundle type called KeyScalarBundle, which has + special case handling for functions that consume the key such as set_seed. + """ + ref, _, _, _ = args_tree.unflatten(args_flat) + (aval_out,) = ctx.avals_out + assert isinstance(aval_out.dtype, prng.KeyTy) + ref_block_shape = aval_out.dtype._impl.key_shape + + if len(ref_block_shape) != 1: + raise NotImplementedError("Seed key_data must be 1D.") + if ref_block_shape[0] != 1: + raise NotImplementedError("Seed key_data of shape != (1,) not supported.") + + load_ops = [] + for i in range(ref_block_shape[0]): + idx = NDIndexer(indices=(i,), shape=ref_block_shape, + int_indexer_shape=tuple()) + starts, _, _, _, _ = _indexer_to_start_size_stride( + idx, + ref_block_shape, + cast_to_index=True, + ) + load_ops.append(memref.LoadOp(ref, starts).result) + return KeyScalarBundle(scalars=load_ops) + + lowering_rules[primitives.load_p] = _load_lowering_rule skip_mlir_conversions.add(primitives.load_p) @@ -2393,6 +2458,15 @@ lowering_rules[primitives.debug_print_p] = _debug_print_rule def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): del ctx + # In the KeyScalarBundle case we unpack the bundle and set the seed with + # the list of scalars. + if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle): + return tpu.PRNGSeed32Op(seeds[0].scalars).results + # For integer seeds, we can set the seed directly as PRNGSeed32Op natively + # takes in a list of integers as input. + all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds) + if not all_integers: + raise ValueError("All seed data must be integers.") return tpu.PRNGSeed32Op(seeds).results lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule @@ -2434,12 +2508,27 @@ lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering def random_unwrap_lowering(ctx, key): - del ctx - return key + del ctx, key + raise NotImplementedError("key_data not implemented.") lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering def random_wrap_lowering(ctx, key_data, *, impl): del ctx, impl - return key_data + if isinstance(key_data.type, ir.VectorType): + # If the key data lives in vregs, need to unpack it to sregs. 
+ key_data_list = [] + key_data_shape = key_data.type.shape + if len(key_data_shape) != 1: + raise NotImplementedError("Seed key_data must be 1D.") + if key_data_shape[0] != 1: + raise NotImplementedError("key_data with shape != (1,) not supported.") + for i in range(key_data_shape[0]): + key_data_list.append(vector.ExtractOp(key_data, [], [i])) + return KeyScalarBundle(scalars=key_data_list) + if isinstance(key_data, KeyScalarBundle): + return key_data + else: + raise NotImplementedError(f"key_data wrap {type(key_data)}") + lowering_rules[prng.random_wrap_p] = random_wrap_lowering diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index c95a146f2..16ae315f9 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -21,6 +21,7 @@ import warnings import jax from jax import core as jax_core +from jax._src import core as jax_src_core from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -90,6 +91,13 @@ def pallas_call_tpu_lowering_rule( for a in input_output_aliases ) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + + # Replace in_avals to physical avals. + # This step is required for mapping logical types to physical types. + # (e.g. PRNG key -> uint32[2]) + physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + ctx = ctx.replace(avals_in=physical_avals) + def _lower_fun(*args): # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:], diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index b85d5b559..89e18f637 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Union +from typing import Any, Callable import jax import numpy as np @@ -24,8 +24,13 @@ from jax._src import prng as jax_prng Shape = jax_prng.Shape +SampleFnType = Any +KeylessSampleFnType = Callable[..., jax.Array] + +set_seed = prng_seed + FOLD_IN_ROUNDS = 128 -SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas"] +SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas_tpu"] def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" @@ -45,7 +50,7 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") pallas_key_data = jnp.ravel(key_data)[:pallas_key_size] pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape) - return jax_api_random.wrap_key_data(pallas_key_data, impl='pallas') + return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") def _seed_func(seed: jnp.int32): seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) @@ -53,12 +58,8 @@ def _seed_func(seed: jnp.int32): def _random_bits(key: typing.Array, bit_width: int, shape: Shape): if bit_width != 32: - raise NotImplementedError("Bit width must be 32") - if isinstance(key.dtype, jax_prng.KeyTy): - key_data = jax.random.key_data(key) - else: - key_data = key - prng_seed(key_data[0, 0]) + raise ValueError("Bit width must be 32") + prng_seed(key) return prng_random_bits(shape) def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): @@ -67,17 +68,17 @@ def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): # Because the TPU generates random numbers in (8, 128) blocks at once, we # can generate that many values without additional cost which will reduce # correlation between the old and new keys. - key_data = jax.random.key_data(key) + key_shape = tpu_key_impl.key_shape prng_seed(data) data_bits = prng_random_bits( - key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - prng_seed(key_data[0, 0]) + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + prng_seed(key) key_bits = prng_random_bits( - key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] - assert mixed.shape == key_data.shape + assert mixed.shape == key_shape impl: jax_prng.PRNGSpec = jax.random.key_impl(key) # type: ignore return jax.random.wrap_key_data(mixed, impl=impl) @@ -86,13 +87,12 @@ def _split(key: typing.Array, shape: Shape): raise NotImplementedError() tpu_key_impl = jax_prng.PRNGImpl( - # Use a 2D key since pallas only supports 2D tiling. - key_shape=(1, 1), + key_shape=(1,), seed=_seed_func, split=_split, random_bits=_random_bits, fold_in=_fold_in, - name="pallas", + name="pallas_tpu", tag="pl" ) jax_prng.register_prng(tpu_key_impl) @@ -135,26 +135,6 @@ tpu_internal_stateful_impl = jax_prng.PRNGImpl( ) jax_prng.register_prng(tpu_internal_stateful_impl) -def set_seed(seed: Union[jnp.int32, jax.Array]): - """Sets the seed for PRNG. - - Args: - seeds: An integer seed for setting the PRNG seed. - """ - if isinstance(seed, jax.Array): - if seed.ndim != 1: - raise ValueError("Seed data must be a scalar or 1D array") - # TODO(justinfu): Mosaic currently only supports indexing by 0 - # for scalar results when using vector.extract - # After support is added, use all seed data. 
- prng_seed(seed[0]) - else: - prng_seed(seed) - - -SampleFnType = Any -KeylessSampleFnType = Callable[..., jax.Array] - def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType: """Converts a jax.random sampling function to a stateful version. @@ -174,8 +154,8 @@ def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType: return sampler(placeholder_key, *args, **kwargs) # Remove key argument from docstring. doc_lines = filter( - lambda line: 'key:' not in line, sampler.__doc__.split('\n')) - new_sampler.__doc__ = '\n'.join(doc_lines) + lambda line: "key:" not in line, sampler.__doc__.split("\n")) + new_sampler.__doc__ = "\n".join(doc_lines) return new_sampler bits = _make_stateful_sampler(jax_api_random.bits) diff --git a/tests/pallas/tpu/pallas_random_test.py b/tests/pallas/tpu/pallas_random_test.py index 590ed7da3..0a6c22dea 100644 --- a/tests/pallas/tpu/pallas_random_test.py +++ b/tests/pallas/tpu/pallas_random_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized import jax from jax import random as jax_random from jax._src import test_util as jtu +from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu @@ -27,6 +28,7 @@ import numpy as np jax.config.parse_flags_with_absl() + class PRNGTest(jtu.JaxTestCase): def setUp(self): @@ -90,32 +92,63 @@ class PRNGTest(jtu.JaxTestCase): # Test stateful RNG using the jax.random API wrappers. def body(key_ref, o_ref): plrandom.set_seed(key_ref[...]) - samples = plrandom.uniform( + o_ref[...] = plrandom.uniform( shape=o_ref[...].shape, minval=0.0, maxval=1.0) - o_ref[...] = samples - key = jax_random.key_data(jax_random.key(0, impl="rbg")) + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) - result = pl.pallas_call(body, out_shape=o_shape)(key) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) self.assertGreaterEqual(jnp.min(result), 0) self.assertLessEqual(jnp.max(result), 1.0) def test_stateless_uniform_sample(self): # Test keyed RNG using the jax.random API. def body(key_ref, o_ref): - key = jax_random.wrap_key_data(key_ref[...], impl="pallas") - samples = jax_random.uniform( - key, shape=o_ref[...].shape, minval=0.0, maxval=1.0) - o_ref[...] = samples + o_ref[...] = jax_random.uniform( + key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0 + ) rbg_key = jax_random.key(0, impl="rbg") - pallas_key = plrandom.to_pallas_key(rbg_key) - key = jax_random.key_data(pallas_key) + key = plrandom.to_pallas_key(rbg_key) o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) - result = pl.pallas_call(body, out_shape=o_shape)(key) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) self.assertGreaterEqual(jnp.min(result), 0) self.assertLessEqual(jnp.max(result), 1.0) + def test_fold_in(self): + # Test that folding in a value results in different random numbers. + def body(key_ref, o_ref): + key = key_ref[...] + o_ref[0, ...] = jax_random.uniform( + key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0 + ) + + key = jax_random.fold_in(key, 2) + o_ref[1, ...] 
= jax_random.uniform( + key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0 + ) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + result_a = result[0] + result_b = result[1] + np.testing.assert_array_compare(np.not_equal, result_a, result_b) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())
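
Usage sketch (not part of the patch itself; distilled from the new test_stateless_uniform_sample test above, with illustrative shapes): a kernel receives a "pallas_tpu" key through an SMEM BlockSpec and samples from it with the standard jax.random API inside the kernel.

    import jax
    import jax.numpy as jnp
    from jax import random as jax_random
    from jax._src.pallas.mosaic import core as tpu_core
    from jax._src.pallas.mosaic import random as plrandom
    from jax.experimental import pallas as pl

    def kernel(key_ref, o_ref):
      # key_ref lives in SMEM; indexing it yields a PRNG key that the
      # jax.random API can consume inside the kernel.
      o_ref[...] = jax_random.uniform(
          key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0)

    # Convert a host-side key to the Pallas TPU key impl before passing it in.
    rbg_key = jax_random.key(0, impl="rbg")
    pallas_key = plrandom.to_pallas_key(rbg_key)
    out = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(pallas_key)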