mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Move PRNG GPU lowering from jaxlib into JAX.
PiperOrigin-RevId: 738398099
This commit is contained in:
parent
1e25c44d67
commit
d7d0aa943e
@ -31,6 +31,7 @@ from jax._src import config as config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import ffi
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util as tree_util_internal
|
||||
@ -64,6 +65,13 @@ Shape = tuple[int, ...]
|
||||
UINT_DTYPES = {
|
||||
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
|
||||
|
||||
if hasattr(gpu_prng, "registrations"):
|
||||
for platform, targets in gpu_prng.registrations().items():
|
||||
for name, value, api_version in targets:
|
||||
ffi.register_ffi_target(
|
||||
name, value, platform=platform, api_version=api_version
|
||||
)
|
||||
|
||||
# -- PRNG implementation interface
|
||||
|
||||
class PRNGImpl(NamedTuple):
|
||||
@ -902,7 +910,7 @@ _threefry2x32_cpu_lowering_rule = mlir.lower_fun(
|
||||
multiple_results=True)
|
||||
|
||||
|
||||
def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
|
||||
def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix):
|
||||
if not config.threefry_gpu_kernel_lowering.value: # back to default lowering
|
||||
return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
|
||||
|
||||
@ -917,23 +925,11 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
|
||||
return mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=range(rank - len(aval.shape), rank))
|
||||
|
||||
out_len = reduce(op.mul, aval_out.shape, 1)
|
||||
if not core.is_constant_dim(out_len):
|
||||
length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
|
||||
length = mlir.hlo.convert(
|
||||
ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
|
||||
length)
|
||||
output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
|
||||
else:
|
||||
length = int(out_len) # will be passed statically
|
||||
output_shape = None
|
||||
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape,
|
||||
False, # forward_compatibility_mode
|
||||
)
|
||||
sub_ctx = ctx.replace(avals_in=(aval_out,) * 4)
|
||||
rule = ffi.ffi_lowering(
|
||||
f"{target_name_prefix}_threefry2x32_ffi")
|
||||
return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval),
|
||||
_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))
|
||||
|
||||
|
||||
threefry2x32_p = core.Primitive("threefry2x32")
|
||||
@ -947,11 +943,11 @@ mlir.register_lowering(
|
||||
threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p,
|
||||
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
|
||||
partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p,
|
||||
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
|
||||
partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
|
@ -12,79 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import partial
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call
|
||||
from .plugin_support import import_from_plugin
|
||||
|
||||
_cuda_prng = import_from_plugin("cuda", "_prng")
|
||||
_hip_prng = import_from_plugin("rocm", "_prng")
|
||||
|
||||
if _cuda_prng:
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
# TODO(danfm): remove after JAX 0.5.1 release
|
||||
api_version = 1 if "_ffi" in _name else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
|
||||
api_version=api_version)
|
||||
|
||||
if _hip_prng:
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
# TODO(danfm): remove after JAX 0.5.1 release
|
||||
api_version = 1 if "_ffi" in _name else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
|
||||
api_version=api_version)
|
||||
|
||||
|
||||
def _threefry2x32_lowering(prng, platform: str, keys, data,
|
||||
length: int | ir.Value | None = None,
|
||||
output_shape: ir.Value | None = None,
|
||||
forward_compatibility_mode: bool = False):
|
||||
"""ThreeFry2x32 kernel for GPU.
|
||||
|
||||
In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
|
||||
is a 1D tensor describing the shape of the two outputs.
|
||||
"""
|
||||
del forward_compatibility_mode
|
||||
assert len(keys) == 2, keys
|
||||
assert len(data) == 2, data
|
||||
assert (ir.RankedTensorType(keys[0].type).element_type ==
|
||||
ir.IntegerType.get_unsigned(32)), keys[0].type
|
||||
|
||||
typ = keys[0].type
|
||||
dims = ir.RankedTensorType(typ).shape
|
||||
|
||||
for x in itertools.chain(keys, data):
|
||||
assert x.type == typ, (x.type, typ)
|
||||
ndims = len(dims)
|
||||
layout = tuple(range(ndims - 1, -1, -1))
|
||||
operand_layouts = [layout] * 4
|
||||
operands = [keys[0], keys[1], data[0], data[1]]
|
||||
|
||||
opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4).
|
||||
if isinstance(length, int):
|
||||
result_shapes = None
|
||||
else:
|
||||
assert output_shape is not None
|
||||
# We also need to pass separately the shapes of the outputs.
|
||||
result_shapes = [output_shape, output_shape]
|
||||
|
||||
custom_call_target = f"{platform}_threefry2x32_ffi"
|
||||
return custom_call(
|
||||
custom_call_target,
|
||||
api_version=4,
|
||||
result_types=[typ, typ],
|
||||
operands=operands,
|
||||
backend_config=opaque,
|
||||
operand_layouts=operand_layouts,
|
||||
result_layouts=[layout] * 2,
|
||||
result_shapes=result_shapes).results
|
||||
|
||||
|
||||
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu")
|
||||
rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip")
|
||||
def registrations() -> dict[str, list[tuple[str, Any, int]]]:
|
||||
registrations = {"CUDA": [], "ROCM": []}
|
||||
for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]:
|
||||
if module:
|
||||
registrations[platform].extend(
|
||||
(name, value, int(name.endswith("_ffi")))
|
||||
for name, value in module.registrations().items())
|
||||
return registrations
|
||||
|
Loading…
x
Reference in New Issue
Block a user