mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas] Allow keys as input to Pallas kernels.
PiperOrigin-RevId: 642740833
This commit is contained in:
parent
b7a8f9d584
commit
4b81680b62
@ -28,6 +28,7 @@ from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import debugging
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import pjit
|
||||
@ -150,6 +151,15 @@ def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None
|
||||
raise ValueError(f"Cannot allocate {aval.sem_type}.")
|
||||
memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE)
|
||||
return ir.MemRefType.get((), sem_type, memory_space=memspace)
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
|
||||
shape = aval.dtype._impl.key_shape
|
||||
if memory_space is None:
|
||||
memory_space = TPUMemorySpace.SMEM
|
||||
if memory_space != TPUMemorySpace.SMEM:
|
||||
raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}")
|
||||
memspace = _memory_space_to_tpu_memspace(memory_space)
|
||||
return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
|
||||
memory_space=memspace)
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
if shape is None:
|
||||
shape = aval.shape
|
||||
@ -626,7 +636,8 @@ def jaxpr_subcomp(
|
||||
return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
|
||||
|
||||
def write_env(var: jax_core.Var, val):
|
||||
assert isinstance(val, ir.Value), type(val)
|
||||
is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
|
||||
assert is_valid_type, type(val)
|
||||
env[var] = val
|
||||
|
||||
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
|
||||
@ -703,6 +714,8 @@ def jaxpr_subcomp(
|
||||
def _ensure_mlir_value(val, aval):
|
||||
if isinstance(val, ir.Value):
|
||||
return val
|
||||
if isinstance(val, KeyScalarBundle):
|
||||
return val
|
||||
elif isinstance(val, (np.generic, np.ndarray, int, float)):
|
||||
return ir_constant(val, _dtype_to_ir_type(aval.dtype))
|
||||
else:
|
||||
@ -898,6 +911,21 @@ def _index_ref(ref, ref_aval, ref_block_shape, indexers):
|
||||
ref_block_shape)
|
||||
return ref, ref_block_shape
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
  """Holds the scalar pieces of a PRNG key during lowering.

  The lowering pass carries key data around as a bundle of individual scalars
  instead of as a vector, because the key data should live in scalar
  registers rather than vector registers. A dedicated container is needed
  because the load_op primitive cannot return more than one result, so the
  scalars are handed back together inside this object.

  Attributes:
    scalars: The OpResults that make up the scalar key data while lowering.
  """

  scalars: list[ir.OpResult]
|
||||
|
||||
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
|
||||
ref, indexers, mask, _ = args_tree.unflatten(args_flat)
|
||||
@ -916,6 +944,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
|
||||
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
|
||||
ref_aval, *_ = ctx.avals_in
|
||||
(aval_out,) = ctx.avals_out
|
||||
if isinstance(aval_out.dtype, prng.KeyTy):
|
||||
if not is_smem_load:
|
||||
raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
|
||||
"the memory space to TPUMemorySpace.SMEM in the "
|
||||
"BlockSpec for the PRNG key input?")
|
||||
return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
|
||||
if not is_smem_load and not ref_block_shape:
|
||||
raise NotImplementedError(
|
||||
"Indexing into a ()-shaped Ref not yet supported on TPU.")
|
||||
@ -947,6 +981,37 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
|
||||
_dtype_to_ir_type(aval_out.dtype))
|
||||
return vector.ShapeCastOp(vec_type, load_val).result
|
||||
|
||||
def _prng_key_load_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree
) -> KeyScalarBundle:
  """Lowers a PRNG key load from SMEM into per-element scalar loads.

  A key is currently read as a sequence of scalar loads from SMEM instead of
  one vector load. The loaded scalars are returned in a KeyScalarBundle, which
  key consumers such as set_seed handle as a special case.

  Raises:
    NotImplementedError: If the key implementation's key_shape is not (1,).
  """
  ref, _, _, _ = args_tree.unflatten(args_flat)
  (aval_out,) = ctx.avals_out
  assert isinstance(aval_out.dtype, prng.KeyTy)
  key_shape = aval_out.dtype._impl.key_shape

  if len(key_shape) != 1:
    raise NotImplementedError("Seed key_data must be 1D.")
  num_scalars = key_shape[0]
  if num_scalars != 1:
    raise NotImplementedError("Seed key_data of shape != (1,) not supported.")

  def _load_scalar(position):
    # Build a static indexer for one key element and load it from the ref.
    indexer = NDIndexer(
        indices=(position,), shape=key_shape, int_indexer_shape=()
    )
    starts, _, _, _, _ = _indexer_to_start_size_stride(
        indexer,
        key_shape,
        cast_to_index=True,
    )
    return memref.LoadOp(ref, starts).result

  return KeyScalarBundle(
      scalars=[_load_scalar(i) for i in range(num_scalars)]
  )
|
||||
|
||||
|
||||
|
||||
lowering_rules[primitives.load_p] = _load_lowering_rule
|
||||
skip_mlir_conversions.add(primitives.load_p)
|
||||
@ -2393,6 +2458,15 @@ lowering_rules[primitives.debug_print_p] = _debug_print_rule
|
||||
|
||||
def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  """Lowers a PRNG seed call to a tpu.PRNGSeed32Op.

  A single KeyScalarBundle argument is unpacked and its scalars become the
  seed. Otherwise the seeds are passed to the op as given, which requires each
  of them to have an integer IR type.

  Raises:
    ValueError: If a seed outside the bundle case is not of ir.IntegerType.
  """
  del ctx
  # Bundle case: seed with the scalars carried by the bundle.
  if len(seeds) == 1:
    (only_seed,) = seeds
    if isinstance(only_seed, KeyScalarBundle):
      return tpu.PRNGSeed32Op(only_seed.scalars).results
  # Integer case: PRNGSeed32Op natively accepts a list of integers.
  if any(not isinstance(s.type, ir.IntegerType) for s in seeds):
    raise ValueError("All seed data must be integers.")
  return tpu.PRNGSeed32Op(seeds).results
|
||||
lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule
|
||||
|
||||
@ -2434,12 +2508,27 @@ lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering
|
||||
|
||||
|
||||
def random_unwrap_lowering(ctx, key):
  """Lowering rule for unwrapping a PRNG key into raw key data.

  Unwrapping is not supported by this lowering, so the rule rejects every
  call instead of passing the key through.

  Args:
    ctx: Lowering rule context (unused).
    key: The lowered key value (unused).

  Raises:
    NotImplementedError: Always.
  """
  del ctx, key
  raise NotImplementedError("key_data not implemented.")
|
||||
lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering
|
||||
|
||||
|
||||
def random_wrap_lowering(ctx, key_data, *, impl):
  """Lowering rule for wrapping raw key data into a PRNG key.

  Keys are represented as a KeyScalarBundle during lowering. Key data that is
  already a bundle is returned unchanged; key data held in a 1D vector of
  length 1 is unpacked element by element into scalars.

  Args:
    ctx: Lowering rule context (unused).
    key_data: Either a KeyScalarBundle or an IR value holding the key data.
    impl: The PRNG implementation (unused).

  Returns:
    A KeyScalarBundle holding the scalar key data.

  Raises:
    NotImplementedError: If vector key data is not 1D, does not have shape
      (1,), or if key_data is neither a bundle nor a vector-typed value.
  """
  del ctx, impl
  # Check for the bundle first: KeyScalarBundle is a plain dataclass without a
  # `.type` attribute, so probing `key_data.type` on it would raise
  # AttributeError before the bundle case could be reached.
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  if isinstance(key_data.type, ir.VectorType):
    # If the key data lives in vregs, need to unpack it to sregs.
    key_data_shape = key_data.type.shape
    if len(key_data_shape) != 1:
      raise NotImplementedError("Seed key_data must be 1D.")
    if key_data_shape[0] != 1:
      raise NotImplementedError("key_data with shape != (1,) not supported.")
    # Store the op results (not the ops) so the bundle matches its
    # `list[ir.OpResult]` annotation and the bundles built by key loads.
    key_data_list = [
        vector.ExtractOp(key_data, [], [i]).result
        for i in range(key_data_shape[0])
    ]
    return KeyScalarBundle(scalars=key_data_list)
  raise NotImplementedError(f"key_data wrap {type(key_data)}")
|
||||
|
||||
lowering_rules[prng.random_wrap_p] = random_wrap_lowering
|
||||
|
@ -21,6 +21,7 @@ import warnings
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src import core as jax_src_core
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -90,6 +91,13 @@ def pallas_call_tpu_lowering_rule(
|
||||
for a in input_output_aliases
|
||||
)
|
||||
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
|
||||
|
||||
# Replace in_avals with their physical avals.
|
||||
# This step is required for mapping logical types to physical types.
|
||||
# (e.g. PRNG key -> uint32[2])
|
||||
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
|
||||
ctx = ctx.replace(avals_in=physical_avals)
|
||||
|
||||
def _lower_fun(*args):
|
||||
# Dynamic grid bounds have to go at the front.
|
||||
dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:],
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
@ -24,8 +24,13 @@ from jax._src import prng as jax_prng
|
||||
|
||||
|
||||
Shape = jax_prng.Shape
|
||||
SampleFnType = Any
|
||||
KeylessSampleFnType = Callable[..., jax.Array]
|
||||
|
||||
set_seed = prng_seed
|
||||
|
||||
FOLD_IN_ROUNDS = 128
|
||||
SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas"]
|
||||
SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas_tpu"]
|
||||
|
||||
def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
|
||||
"""Helper function for converting non-Pallas PRNG keys into Pallas keys."""
|
||||
@ -45,7 +50,7 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
|
||||
raise ValueError(f"Key data must be at least {pallas_key_size} bytes.")
|
||||
pallas_key_data = jnp.ravel(key_data)[:pallas_key_size]
|
||||
pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape)
|
||||
return jax_api_random.wrap_key_data(pallas_key_data, impl='pallas')
|
||||
return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu")
|
||||
|
||||
def _seed_func(seed: jnp.int32):
|
||||
seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32)
|
||||
@ -53,12 +58,8 @@ def _seed_func(seed: jnp.int32):
|
||||
|
||||
def _random_bits(key: typing.Array, bit_width: int, shape: Shape):
  """Seeds the PRNG with `key` and returns random bits of shape `shape`.

  Args:
    key: The key passed to prng_seed.
    bit_width: Requested bit width; only 32 is supported.
    shape: Shape passed to prng_random_bits.

  Raises:
    ValueError: If bit_width is not 32.
  """
  if bit_width != 32:
    raise ValueError("Bit width must be 32")
  # The key is handed to prng_seed as-is; no key_data unwrapping happens here.
  prng_seed(key)
  return prng_random_bits(shape)
|
||||
|
||||
def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array):
|
||||
@ -67,17 +68,17 @@ def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array):
|
||||
# Because the TPU generates random numbers in (8, 128) blocks at once, we
|
||||
# can generate that many values without additional cost which will reduce
|
||||
# correlation between the old and new keys.
|
||||
key_data = jax.random.key_data(key)
|
||||
key_shape = tpu_key_impl.key_shape
|
||||
|
||||
prng_seed(data)
|
||||
data_bits = prng_random_bits(
|
||||
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
|
||||
prng_seed(key_data[0, 0])
|
||||
key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
|
||||
prng_seed(key)
|
||||
key_bits = prng_random_bits(
|
||||
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
|
||||
key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
|
||||
|
||||
mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1]
|
||||
assert mixed.shape == key_data.shape
|
||||
assert mixed.shape == key_shape
|
||||
impl: jax_prng.PRNGSpec = jax.random.key_impl(key) # type: ignore
|
||||
return jax.random.wrap_key_data(mixed, impl=impl)
|
||||
|
||||
@ -86,13 +87,12 @@ def _split(key: typing.Array, shape: Shape):
|
||||
raise NotImplementedError()
|
||||
|
||||
tpu_key_impl = jax_prng.PRNGImpl(
|
||||
# Use a 2D key since pallas only supports 2D tiling.
|
||||
key_shape=(1, 1),
|
||||
key_shape=(1,),
|
||||
seed=_seed_func,
|
||||
split=_split,
|
||||
random_bits=_random_bits,
|
||||
fold_in=_fold_in,
|
||||
name="pallas",
|
||||
name="pallas_tpu",
|
||||
tag="pl"
|
||||
)
|
||||
jax_prng.register_prng(tpu_key_impl)
|
||||
@ -135,26 +135,6 @@ tpu_internal_stateful_impl = jax_prng.PRNGImpl(
|
||||
)
|
||||
jax_prng.register_prng(tpu_internal_stateful_impl)
|
||||
|
||||
def set_seed(seed: Union[jnp.int32, jax.Array]):
|
||||
"""Sets the seed for PRNG.
|
||||
|
||||
Args:
|
||||
seeds: An integer seed for setting the PRNG seed.
|
||||
"""
|
||||
if isinstance(seed, jax.Array):
|
||||
if seed.ndim != 1:
|
||||
raise ValueError("Seed data must be a scalar or 1D array")
|
||||
# TODO(justinfu): Mosaic currently only supports indexing by 0
|
||||
# for scalar results when using vector.extract
|
||||
# After support is added, use all seed data.
|
||||
prng_seed(seed[0])
|
||||
else:
|
||||
prng_seed(seed)
|
||||
|
||||
|
||||
SampleFnType = Any
|
||||
KeylessSampleFnType = Callable[..., jax.Array]
|
||||
|
||||
def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType:
|
||||
"""Converts a jax.random sampling function to a stateful version.
|
||||
|
||||
@ -174,8 +154,8 @@ def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType:
|
||||
return sampler(placeholder_key, *args, **kwargs)
|
||||
# Remove key argument from docstring.
|
||||
doc_lines = filter(
|
||||
lambda line: 'key:' not in line, sampler.__doc__.split('\n'))
|
||||
new_sampler.__doc__ = '\n'.join(doc_lines)
|
||||
lambda line: "key:" not in line, sampler.__doc__.split("\n"))
|
||||
new_sampler.__doc__ = "\n".join(doc_lines)
|
||||
return new_sampler
|
||||
|
||||
bits = _make_stateful_sampler(jax_api_random.bits)
|
||||
|
@ -18,6 +18,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import random as jax_random
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.pallas.mosaic import core as tpu_core
|
||||
from jax._src.pallas.mosaic import random as plrandom
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
@ -27,6 +28,7 @@ import numpy as np
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class PRNGTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -90,32 +92,63 @@ class PRNGTest(jtu.JaxTestCase):
|
||||
# Test stateful RNG using the jax.random API wrappers.
|
||||
def body(key_ref, o_ref):
|
||||
plrandom.set_seed(key_ref[...])
|
||||
samples = plrandom.uniform(
|
||||
o_ref[...] = plrandom.uniform(
|
||||
shape=o_ref[...].shape, minval=0.0, maxval=1.0)
|
||||
o_ref[...] = samples
|
||||
|
||||
key = jax_random.key_data(jax_random.key(0, impl="rbg"))
|
||||
rbg_key = jax_random.key(0, impl="rbg")
|
||||
key = plrandom.to_pallas_key(rbg_key)
|
||||
o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
|
||||
result = pl.pallas_call(body, out_shape=o_shape)(key)
|
||||
result = pl.pallas_call(
|
||||
body,
|
||||
in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
|
||||
out_shape=o_shape,
|
||||
)(key)
|
||||
self.assertGreaterEqual(jnp.min(result), 0)
|
||||
self.assertLessEqual(jnp.max(result), 1.0)
|
||||
|
||||
def test_stateless_uniform_sample(self):
  """Tests keyed RNG using the jax.random API inside a Pallas kernel."""

  def body(key_ref, o_ref):
    # The key is loaded straight from the ref and used as a jax.random key.
    o_ref[...] = jax_random.uniform(
        key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0
    )

  rbg_key = jax_random.key(0, impl="rbg")
  key = plrandom.to_pallas_key(rbg_key)
  o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
  result = pl.pallas_call(
      body,
      # The key input is placed in SMEM via its BlockSpec.
      in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
      out_shape=o_shape,
  )(key)
  self.assertGreaterEqual(jnp.min(result), 0)
  self.assertLessEqual(jnp.max(result), 1.0)
|
||||
|
||||
def test_fold_in(self):
  """Checks that folding a value into a key changes the sampled numbers."""

  def body(key_ref, o_ref):
    def sample_into(row, row_key):
      # Draw uniform samples with `row_key` and store them in output row `row`.
      o_ref[row, ...] = jax_random.uniform(
          row_key, shape=o_ref[row, ...].shape, minval=0.0, maxval=1.0
      )

    base_key = key_ref[...]
    sample_into(0, base_key)
    sample_into(1, jax_random.fold_in(base_key, 2))

  pallas_key = plrandom.to_pallas_key(jax_random.key(0, impl="rbg"))
  out_struct = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32)
  kernel = pl.pallas_call(
      body,
      in_specs=[pl.BlockSpec(memory_space=tpu_core.TPUMemorySpace.SMEM)],
      out_shape=out_struct,
  )
  result = kernel(pallas_key)
  before_fold, after_fold = result[0], result[1]
  np.testing.assert_array_compare(np.not_equal, before_fold, after_fold)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user