Move PRNG GPU lowering from jaxlib into JAX.

PiperOrigin-RevId: 738398099
Dan Foreman-Mackey 2025-03-19 07:56:34 -07:00 committed by jax authors
parent 1e25c44d67
commit d7d0aa943e
2 changed files with 25 additions and 90 deletions

diff --git a/jax/_src/prng.py b/jax/_src/prng.py

@@ -31,6 +31,7 @@ from jax._src import config as config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
+from jax._src import ffi
 from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src import tree_util as tree_util_internal
@@ -64,6 +65,13 @@ Shape = tuple[int, ...]
 UINT_DTYPES = {
     8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
 
+if hasattr(gpu_prng, "registrations"):
+  for platform, targets in gpu_prng.registrations().items():
+    for name, value, api_version in targets:
+      ffi.register_ffi_target(
+          name, value, platform=platform, api_version=api_version
+      )
+
 # -- PRNG implementation interface
 
 class PRNGImpl(NamedTuple):
@@ -902,7 +910,7 @@ _threefry2x32_cpu_lowering_rule = mlir.lower_fun(
     multiple_results=True)
 
-def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
+def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix):
   if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
     return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
@@ -917,23 +925,11 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
     return mlir.broadcast_in_dim(ctx, x, aval_out,
                                  broadcast_dimensions=range(rank - len(aval.shape), rank))
 
-  out_len = reduce(op.mul, aval_out.shape, 1)
-  if not core.is_constant_dim(out_len):
-    length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
-    length = mlir.hlo.convert(
-        ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
-        length)
-    output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
-  else:
-    length = int(out_len)  # will be passed statically
-    output_shape = None
-
-  return lowering_func(
-      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
-      output_shape,
-      False,  # forward_compatibility_mode
-  )
+  sub_ctx = ctx.replace(avals_in=(aval_out,) * 4)
+  rule = ffi.ffi_lowering(
+      f"{target_name_prefix}_threefry2x32_ffi")
+  return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval),
+              _broadcast(x1, x1_aval), _broadcast(x2, x2_aval))
 
 
 threefry2x32_p = core.Primitive("threefry2x32")
@@ -947,11 +943,11 @@ mlir.register_lowering(
     threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
 mlir.register_lowering(
     threefry2x32_p,
-    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
+    partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'),
     platform='cuda')
 mlir.register_lowering(
     threefry2x32_p,
-    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
+    partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'),
    platform='rocm')
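
Taken together, the hunks above drop the jaxlib-specific lowering wrappers in favor of JAX's generic FFI machinery: prng.py now registers whatever targets jaxlib advertises and builds the GPU lowering purely from a target name. A minimal sketch of that pattern, assuming a plugin is installed; the helper names here are illustrative, not part of the diff:

# Sketch of the pattern adopted above; `register_gpu_targets` and
# `gpu_lowering` are illustrative names, not part of the diff.
from functools import partial

from jax._src import ffi

def register_gpu_targets(registrations):
  # `registrations` maps platform -> [(target_name, capsule, api_version)],
  # the shape produced by gpu_prng.registrations() in jaxlib.
  for platform, targets in registrations.items():
    for name, value, api_version in targets:
      ffi.register_ffi_target(
          name, value, platform=platform, api_version=api_version)

def gpu_lowering(ctx, *args, target_name_prefix):
  # ffi.ffi_lowering builds an MLIR lowering rule that emits a custom call
  # to the named FFI target; the prefix ("cu" or "hip") selects the kernel,
  # so a single Python rule covers both CUDA and ROCm.
  rule = ffi.ffi_lowering(f"{target_name_prefix}_threefry2x32_ffi")
  return rule(ctx, *args)

# Bound per platform with functools.partial, as in the registrations above:
cuda_rule = partial(gpu_lowering, target_name_prefix="cu")
hip_rule = partial(gpu_lowering, target_name_prefix="hip")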

diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py

@@ -12,79 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from __future__ import annotations
 
-from functools import partial
-import itertools
 from typing import Any
 
-import jaxlib.mlir.ir as ir
-
-from jaxlib import xla_client
-
-from .hlo_helpers import custom_call
 from .plugin_support import import_from_plugin
 
 _cuda_prng = import_from_plugin("cuda", "_prng")
 _hip_prng = import_from_plugin("rocm", "_prng")
-if _cuda_prng:
-  for _name, _value in _cuda_prng.registrations().items():
-    # TODO(danfm): remove after JAX 0.5.1 release
-    api_version = 1 if "_ffi" in _name else 0
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
-                                           api_version=api_version)
-
-if _hip_prng:
-  for _name, _value in _hip_prng.registrations().items():
-    # TODO(danfm): remove after JAX 0.5.1 release
-    api_version = 1 if "_ffi" in _name else 0
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
-                                           api_version=api_version)
-
-def _threefry2x32_lowering(prng, platform: str, keys, data,
-                           length: int | ir.Value | None = None,
-                           output_shape: ir.Value | None = None,
-                           forward_compatibility_mode: bool = False):
-  """ThreeFry2x32 kernel for GPU.
-
-  In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
-  is a 1D tensor describing the shape of the two outputs.
-  """
-  del forward_compatibility_mode
-  assert len(keys) == 2, keys
-  assert len(data) == 2, data
-  assert (ir.RankedTensorType(keys[0].type).element_type ==
-          ir.IntegerType.get_unsigned(32)), keys[0].type
-  typ = keys[0].type
-  dims = ir.RankedTensorType(typ).shape
-
-  for x in itertools.chain(keys, data):
-    assert x.type == typ, (x.type, typ)
-  ndims = len(dims)
-
-  layout = tuple(range(ndims - 1, -1, -1))
-  operand_layouts = [layout] * 4
-  operands = [keys[0], keys[1], data[0], data[1]]
-
-  opaque = {}  # Use if not forward_compatibility_mode to trigger the FFI (v4).
-  if isinstance(length, int):
-    result_shapes = None
-  else:
-    assert output_shape is not None
-    # We also need to pass separately the shapes of the outputs.
-    result_shapes = [output_shape, output_shape]
-
-  custom_call_target = f"{platform}_threefry2x32_ffi"
-  return custom_call(
-      custom_call_target,
-      api_version=4,
-      result_types=[typ, typ],
-      operands=operands,
-      backend_config=opaque,
-      operand_layouts=operand_layouts,
-      result_layouts=[layout] * 2,
-      result_shapes=result_shapes).results
-
-cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu")
-rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip")
+def registrations() -> dict[str, list[tuple[str, Any, int]]]:
+  registrations = {"CUDA": [], "ROCM": []}
+  for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]:
+    if module:
+      registrations[platform].extend(
+          (name, value, int(name.endswith("_ffi")))
+          for name, value in module.registrations().items())
+  return registrations
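
One convention worth noting in the new registrations() helper: target names ending in "_ffi" are advertised with API version 1 (the typed XLA FFI) and anything else with version 0 (legacy custom calls), which is what the int(name.endswith("_ffi")) expression encodes. A toy illustration of the tuple shape it produces, with a stand-in object in place of the compiled _prng plugin module:

# Illustration only: _FakePrng stands in for the compiled _prng plugin,
# whose registrations() maps target names to PyCapsule kernel handles.
class _FakePrng:
  @staticmethod
  def registrations():
    return {"cu_threefry2x32_ffi": object()}

entries = [
    (name, value, int(name.endswith("_ffi")))  # "_ffi" suffix -> API v1
    for name, value in _FakePrng.registrations().items()
]
# entries == [("cu_threefry2x32_ffi", <object>, 1)], the tuple shape
# consumed by the registration loop added to jax/_src/prng.py above.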