[Pallas] Allow keys as input to Pallas kernels.

PiperOrigin-RevId: 642740833
This commit is contained in:
Justin Fu 2024-06-12 14:36:31 -07:00 committed by jax authors
parent b7a8f9d584
commit 4b81680b62
4 changed files with 164 additions and 54 deletions

View File

@ -28,6 +28,7 @@ from jax import tree_util
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
@ -150,6 +151,15 @@ def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None
raise ValueError(f"Cannot allocate {aval.sem_type}.")
memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE)
return ir.MemRefType.get((), sem_type, memory_space=memspace)
if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
shape = aval.dtype._impl.key_shape
if memory_space is None:
memory_space = TPUMemorySpace.SMEM
if memory_space != TPUMemorySpace.SMEM:
raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}")
memspace = _memory_space_to_tpu_memspace(memory_space)
return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
memory_space=memspace)
if isinstance(aval, state.AbstractRef):
if shape is None:
shape = aval.shape
@ -626,7 +636,8 @@ def jaxpr_subcomp(
return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
def write_env(var: jax_core.Var, val):
assert isinstance(val, ir.Value), type(val)
is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
assert is_valid_type, type(val)
env[var] = val
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
@ -703,6 +714,8 @@ def jaxpr_subcomp(
def _ensure_mlir_value(val, aval):
if isinstance(val, ir.Value):
return val
if isinstance(val, KeyScalarBundle):
return val
elif isinstance(val, (np.generic, np.ndarray, int, float)):
return ir_constant(val, _dtype_to_ir_type(aval.dtype))
else:
@ -898,6 +911,21 @@ def _index_ref(ref, ref_aval, ref_block_shape, indexers):
ref_block_shape)
return ref, ref_block_shape
@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
  """A container class for PRNG key data.

  We pass around keys as a KeyScalarBundle in the lowering pass rather than
  as a vector, since we want the key data to live in scalar registers rather
  than vector registers. This special dataclass exists so we can return
  multiple scalar values from load_op, because the load_op primitive does
  not allow multiple results.

  The dataclass is frozen, so the `scalars` attribute cannot be rebound after
  construction; the list it refers to is not copied or otherwise protected.

  Attributes:
    scalars: A list of OpResults representing scalar key data during the
      lowering pass.
  """
  scalars: list[ir.OpResult]
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
ref, indexers, mask, _ = args_tree.unflatten(args_flat)
@ -916,6 +944,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
ref_aval, *_ = ctx.avals_in
(aval_out,) = ctx.avals_out
if isinstance(aval_out.dtype, prng.KeyTy):
if not is_smem_load:
raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
"the memory space to TPUMemorySpace.SMEM in the "
"BlockSpec for the PRNG key input?")
return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
if not is_smem_load and not ref_block_shape:
raise NotImplementedError(
"Indexing into a ()-shaped Ref not yet supported on TPU.")
@ -947,6 +981,37 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
_dtype_to_ir_type(aval_out.dtype))
return vector.ShapeCastOp(vec_type, load_val).result
def _prng_key_load_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree
) -> KeyScalarBundle:
  """Lowers a PRNG key load into one scalar load per key word.

  Instead of emitting a single vector load, every word of the key is read
  with its own `memref.LoadOp`. The loaded scalars are returned together in a
  KeyScalarBundle, the container that key consumers (for example the
  prng_seed lowering) know how to unpack.

  Args:
    ctx: Lowering context. Its single output aval must carry a PRNG key dtype.
    *args_flat: Flattened arguments of the load primitive.
    args_tree: Tree definition that unflattens `args_flat` into four entries.
      Only the first one (the ref) is used here; the remaining three are
      ignored.

  Returns:
    A KeyScalarBundle holding one loaded scalar per key word.

  Raises:
    NotImplementedError: If the key implementation's `key_shape` is not
      exactly `(1,)`.
  """
  ref, _unused_a, _unused_b, _unused_c = args_tree.unflatten(args_flat)
  [aval_out] = ctx.avals_out
  assert isinstance(aval_out.dtype, prng.KeyTy)
  # The load positions come from the key implementation's declared shape.
  key_shape = aval_out.dtype._impl.key_shape
  if len(key_shape) != 1:
    raise NotImplementedError("Seed key_data must be 1D.")
  if key_shape[0] != 1:
    raise NotImplementedError("Seed key_data of shape != (1,) not supported.")

  def load_word(position):
    # Build an integer indexer for this word and turn it into start indices.
    indexer = NDIndexer(
        indices=(position,), shape=key_shape, int_indexer_shape=()
    )
    starts, _, _, _, _ = _indexer_to_start_size_stride(
        indexer, key_shape, cast_to_index=True
    )
    return memref.LoadOp(ref, starts).result

  words = [load_word(position) for position in range(key_shape[0])]
  return KeyScalarBundle(scalars=words)
lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)
@ -2393,6 +2458,15 @@ lowering_rules[primitives.debug_print_p] = _debug_print_rule
def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  """Lowers prng_seed to a `tpu.PRNGSeed32Op`.

  Args:
    ctx: Lowering context (unused).
    *seeds: Either a single KeyScalarBundle, or any number of values whose
      `.type` is an `ir.IntegerType`.

  Returns:
    The results of the emitted `tpu.PRNGSeed32Op`.

  Raises:
    ValueError: If the seeds are not a lone KeyScalarBundle and at least one
      of them is not integer-typed.
  """
  del ctx
  # A lone key bundle is unpacked so the op is seeded with its scalar list.
  if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle):
    (bundle,) = seeds
    return tpu.PRNGSeed32Op(bundle.scalars).results
  # Plain integer seeds are handed to PRNGSeed32Op as they are, since the op
  # takes a list of integers; anything else is rejected.
  for seed in seeds:
    if not isinstance(seed.type, ir.IntegerType):
      raise ValueError("All seed data must be integers.")
  return tpu.PRNGSeed32Op(seeds).results
lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule
@ -2434,12 +2508,27 @@ lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering
def random_unwrap_lowering(ctx, key):
del ctx
return key
del ctx, key
raise NotImplementedError("key_data not implemented.")
lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering
def random_wrap_lowering(ctx, key_data, *, impl):
  """Lowers random_wrap by packaging raw key data as a KeyScalarBundle.

  Args:
    ctx: Lowering context (unused).
    key_data: Either a KeyScalarBundle, which is returned unchanged, or a
      value whose `.type` is an `ir.VectorType` of shape `(1,)`.
    impl: PRNG implementation of the key being wrapped (unused).

  Returns:
    A KeyScalarBundle holding the key data as scalars.

  Raises:
    NotImplementedError: If vector key data is not 1D, is not of shape
      `(1,)`, or if `key_data` is neither a KeyScalarBundle nor a vector
      value.
  """
  del ctx, impl
  # Test for the bundle first. KeyScalarBundle only defines `scalars`, so
  # reading `key_data.type` on it would raise AttributeError and this
  # pass-through could never be reached.
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  # `getattr` keeps inputs without a `.type` on the NotImplementedError path
  # below instead of failing with AttributeError.
  if isinstance(getattr(key_data, "type", None), ir.VectorType):
    # If the key data lives in vregs, need to unpack it to sregs.
    key_data_shape = key_data.type.shape
    if len(key_data_shape) != 1:
      raise NotImplementedError("Seed key_data must be 1D.")
    if key_data_shape[0] != 1:
      raise NotImplementedError("key_data with shape != (1,) not supported.")
    # Store the op results (not the ops) so the bundle matches its
    # `list[ir.OpResult]` annotation, like the bundle built by the key load
    # lowering.
    key_data_list = [
        vector.ExtractOp(key_data, [], [i]).result
        for i in range(key_data_shape[0])
    ]
    return KeyScalarBundle(scalars=key_data_list)
  raise NotImplementedError(f"key_data wrap {type(key_data)}")
lowering_rules[prng.random_wrap_p] = random_wrap_lowering

View File

@ -21,6 +21,7 @@ import warnings
import jax
from jax import core as jax_core
from jax._src import core as jax_src_core
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
@ -90,6 +91,13 @@ def pallas_call_tpu_lowering_rule(
for a in input_output_aliases
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
# Replace in_avals to physical avals.
# This step is required for mapping logical types to physical types.
# (e.g. PRNG key -> uint32[2])
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
ctx = ctx.replace(avals_in=physical_avals)
def _lower_fun(*args):
# Dynamic grid bounds have to go at the front.
dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:],

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable
import jax
import numpy as np
@ -24,8 +24,13 @@ from jax._src import prng as jax_prng
Shape = jax_prng.Shape
SampleFnType = Any
KeylessSampleFnType = Callable[..., jax.Array]
set_seed = prng_seed
FOLD_IN_ROUNDS = 128
SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas"]
SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas_tpu"]
def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
"""Helper function for converting non-Pallas PRNG keys into Pallas keys."""
@ -45,7 +50,7 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
raise ValueError(f"Key data must be at least {pallas_key_size} bytes.")
pallas_key_data = jnp.ravel(key_data)[:pallas_key_size]
pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape)
return jax_api_random.wrap_key_data(pallas_key_data, impl='pallas')
return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu")
def _seed_func(seed: jnp.int32):
seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32)
@ -53,12 +58,8 @@ def _seed_func(seed: jnp.int32):
def _random_bits(key: typing.Array, bit_width: int, shape: Shape):
if bit_width != 32:
raise NotImplementedError("Bit width must be 32")
if isinstance(key.dtype, jax_prng.KeyTy):
key_data = jax.random.key_data(key)
else:
key_data = key
prng_seed(key_data[0, 0])
raise ValueError("Bit width must be 32")
prng_seed(key)
return prng_random_bits(shape)
def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array):
@ -67,17 +68,17 @@ def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array):
# Because the TPU generates random numbers in (8, 128) blocks at once, we
# can generate that many values without additional cost which will reduce
# correlation between the old and new keys.
key_data = jax.random.key_data(key)
key_shape = tpu_key_impl.key_shape
prng_seed(data)
data_bits = prng_random_bits(
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
prng_seed(key_data[0, 0])
key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
prng_seed(key)
key_bits = prng_random_bits(
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1]
assert mixed.shape == key_data.shape
assert mixed.shape == key_shape
impl: jax_prng.PRNGSpec = jax.random.key_impl(key) # type: ignore
return jax.random.wrap_key_data(mixed, impl=impl)
@ -86,13 +87,12 @@ def _split(key: typing.Array, shape: Shape):
raise NotImplementedError()
tpu_key_impl = jax_prng.PRNGImpl(
# Use a 2D key since pallas only supports 2D tiling.
key_shape=(1, 1),
key_shape=(1,),
seed=_seed_func,
split=_split,
random_bits=_random_bits,
fold_in=_fold_in,
name="pallas",
name="pallas_tpu",
tag="pl"
)
jax_prng.register_prng(tpu_key_impl)
@ -135,26 +135,6 @@ tpu_internal_stateful_impl = jax_prng.PRNGImpl(
)
jax_prng.register_prng(tpu_internal_stateful_impl)
def set_seed(seed: Union[jnp.int32, jax.Array]):
"""Sets the seed for PRNG.
Args:
seeds: An integer seed for setting the PRNG seed.
"""
if isinstance(seed, jax.Array):
if seed.ndim != 1:
raise ValueError("Seed data must be a scalar or 1D array")
# TODO(justinfu): Mosaic currently only supports indexing by 0
# for scalar results when using vector.extract
# After support is added, use all seed data.
prng_seed(seed[0])
else:
prng_seed(seed)
SampleFnType = Any
KeylessSampleFnType = Callable[..., jax.Array]
def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType:
"""Converts a jax.random sampling function to a stateful version.
@ -174,8 +154,8 @@ def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType:
return sampler(placeholder_key, *args, **kwargs)
# Remove key argument from docstring.
doc_lines = filter(
lambda line: 'key:' not in line, sampler.__doc__.split('\n'))
new_sampler.__doc__ = '\n'.join(doc_lines)
lambda line: "key:" not in line, sampler.__doc__.split("\n"))
new_sampler.__doc__ = "\n".join(doc_lines)
return new_sampler
bits = _make_stateful_sampler(jax_api_random.bits)

View File

@ -18,6 +18,7 @@ from absl.testing import parameterized
import jax
from jax import random as jax_random
from jax._src import test_util as jtu
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import random as plrandom
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
@ -27,6 +28,7 @@ import numpy as np
jax.config.parse_flags_with_absl()
class PRNGTest(jtu.JaxTestCase):
def setUp(self):
@ -90,32 +92,63 @@ class PRNGTest(jtu.JaxTestCase):
# Test stateful RNG using the jax.random API wrappers.
def body(key_ref, o_ref):
plrandom.set_seed(key_ref[...])
samples = plrandom.uniform(
o_ref[...] = plrandom.uniform(
shape=o_ref[...].shape, minval=0.0, maxval=1.0)
o_ref[...] = samples
key = jax_random.key_data(jax_random.key(0, impl="rbg"))
rbg_key = jax_random.key(0, impl="rbg")
key = plrandom.to_pallas_key(rbg_key)
o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
result = pl.pallas_call(body, out_shape=o_shape)(key)
result = pl.pallas_call(
body,
in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
out_shape=o_shape,
)(key)
self.assertGreaterEqual(jnp.min(result), 0)
self.assertLessEqual(jnp.max(result), 1.0)
def test_stateless_uniform_sample(self):
# Test keyed RNG using the jax.random API.
def body(key_ref, o_ref):
key = jax_random.wrap_key_data(key_ref[...], impl="pallas")
samples = jax_random.uniform(
key, shape=o_ref[...].shape, minval=0.0, maxval=1.0)
o_ref[...] = samples
o_ref[...] = jax_random.uniform(
key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0
)
rbg_key = jax_random.key(0, impl="rbg")
pallas_key = plrandom.to_pallas_key(rbg_key)
key = jax_random.key_data(pallas_key)
key = plrandom.to_pallas_key(rbg_key)
o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
result = pl.pallas_call(body, out_shape=o_shape)(key)
result = pl.pallas_call(
body,
in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
out_shape=o_shape,
)(key)
self.assertGreaterEqual(jnp.min(result), 0)
self.assertLessEqual(jnp.max(result), 1.0)
def test_fold_in(self):
  """Folding data into a key must change the random numbers it produces."""

  def body(key_ref, o_ref):
    base_key = key_ref[...]
    # Row 0 is sampled with the key exactly as it was passed in.
    o_ref[0, ...] = jax_random.uniform(
        base_key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0
    )
    # Row 1 is sampled after folding the constant 2 into that key.
    folded_key = jax_random.fold_in(base_key, 2)
    o_ref[1, ...] = jax_random.uniform(
        folded_key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0
    )

  key = plrandom.to_pallas_key(jax_random.key(0, impl="rbg"))
  # The key input is given an SMEM block spec.
  smem_spec = pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)
  result = pl.pallas_call(
      body,
      in_specs=[smem_spec],
      out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),
  )(key)
  unfolded_samples, folded_samples = result[0], result[1]
  np.testing.assert_array_compare(
      np.not_equal, unfolded_samples, folded_samples
  )
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())