Mirror of https://github.com/ROCm/jax.git
Add support for dynamic shapes to GPU threefry2x32 custom call.
In the presence of dynamic shapes, the ThreeFry2x32Descriptor will contain the value n=-1, and the actual desired output length will be passed as an additional operand. If the shape is static, the length will be passed as part of the descriptor.

PiperOrigin-RevId: 497945778
parent 0c8a4fb7cd
commit 7d452adfd3
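
The convention at a glance, as a hedged sketch (this is not code from the commit; pack_threefry_length is a hypothetical helper, and only the n=-1 sentinel and the extra i64[1] length operand are taken from the diffs below):

def pack_threefry_length(length):
  # Hypothetical helper, not in the commit: returns the n to store in
  # ThreeFry2x32Descriptor and any extra operands for the custom call.
  if isinstance(length, int):
    # Static shape: bake the output length directly into the descriptor.
    return length, []
  # Dynamic shape: the descriptor carries the sentinel n = -1 and the
  # runtime length travels as an additional i64[1] operand; on the kernel
  # side this shifts the two outputs from buffers[4]/buffers[5] to
  # buffers[5]/buffers[6].
  return -1, [length]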
@@ -44,7 +44,7 @@ from jax._src.lax import utils as lax_utils
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy import lax_numpy
 from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
-from jax._src.lib import gpu_prng
+from jax._src.lib import gpu_prng, xla_extension_version
 
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
@@ -940,7 +940,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)
 
 
-def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
+def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
   aval_out, _ = ctx.avals_out
   k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
   rank = len(aval_out.shape)
@@ -950,10 +950,24 @@ def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
   def _broadcast(x, aval):
     return mlir.broadcast_in_dim(ctx, x, aval_out,
                                  broadcast_dimensions=range(rank - len(aval.shape), rank))
-  return threefry2x32_lowering(
-      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
+
+  if xla_extension_version >= 113:
+    out_len = reduce(op.mul, aval_out.shape, 1)
+    if not core.is_constant_dim(out_len):
+      length = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, [out_len]))
+      length = mlir.hlo.ConvertOp(
+          mlir.ir.RankedTensorType.get((1,), mlir.ir.IntegerType.get_signless(64)),
+          length).result
+    else:
+      length = int(out_len)  # will be passed statically
+
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length)
+  else:
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
 
 
 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
@@ -1873,8 +1873,16 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 poly_axes=[None, 0]),
     PolyHarness("random_categorical", "axis=1",
                 lambda key, a: jax.random.categorical(key, a, axis=1),
-                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
-                poly_axes=[None, 0]),
+                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 8), _f32)],
+                poly_axes=[None, (0, 1)]),
+    PolyHarness("random_categorical", "axis=1_then_reshape",
+                lambda key, a: jax.random.categorical(key, a, axis=1).reshape((-1)),
+                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 8), _f32)],
+                poly_axes=[None, (0, 1)]),
+    PolyHarness("random_categorical", "0_dim",  # One axis has 0 size
+                lambda key, a: jax.random.categorical(key, a, axis=1),
+                arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 0), _f32)],
+                poly_axes=[None, (0, 1)]),
     # Works when the known dimensions are known to be even or odd.
     PolyHarness("random_uniform", "even_1",
                 lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
@@ -2262,9 +2270,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
           "householder_product:cpu", "householder_product:gpu",
           "vmap_geqrf:cpu", "vmap_geqrf:gpu",
           "vmap_lu:cpu", "vmap_lu:gpu", "vmap_qr:cpu", "vmap_qr:gpu",
-          "random_gamma:gpu", "vmap_random_gamma:gpu",
-          "random_categorical:gpu", "vmap_random_categorical:gpu",
-          "random_randint:gpu", "random_uniform:gpu", "vmap_random_split:gpu",
           "vmap_svd:cpu", "vmap_svd:gpu"}
       if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
         raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")
@@ -107,15 +107,26 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
   std::array<const std::uint32_t*, 2> data;
   data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
   data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
+  std::int64_t n = descriptor.n;
+  int output_idx = 4;
+  if (n < 0) {
+    // n is an operand in device memory.
+    gpuMemcpyAsync((void*)&n, reinterpret_cast<const std::int64_t*>(buffers[4]),
+                   sizeof(n), gpuMemcpyDeviceToHost,
+                   stream);
+    gpuStreamSynchronize(stream);
+    output_idx = 5;
+  }
+
   std::array<std::uint32_t*, 2> out;
-  out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
-  out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
+  out[0] = reinterpret_cast<std::uint32_t*>(buffers[output_idx]);
+  out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
   const int block_dim = 128;
   const std::int64_t grid_dim =
-      std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
+      std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
   ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
                        stream>>>(keys[0], keys[1], data[0], data[1], out[0],
-                                 out[1], descriptor.n);
+                                 out[1], n);
 }
 
 }  // namespace JAX_GPU_NAMESPACE
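A note on the design visible above: the launch geometry depends on n, so when n is only available in device memory the host must copy it back and block on the stream before it can size the grid. A minimal restatement of the grid-sizing arithmetic, as a sketch (the value of n is made up for illustration):

block_dim = 128
n = 1_000_000  # example output length, made up for illustration
grid_dim = min(1024, (n + block_dim - 1) // block_dim)  # ceil-divide, cap at 1024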
@@ -26,7 +26,7 @@ namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
 struct ThreeFry2x32Descriptor {
-  std::int64_t n;
+  std::int64_t n;  // If -1 then the length is passed as a 5th operand.
 };
 
 void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
@@ -17,6 +17,7 @@ import functools
 from functools import partial
 import itertools
 import operator
+from typing import Optional, Union
 
 import jaxlib.mlir.ir as ir
@@ -41,29 +42,44 @@ except ImportError:
 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)
 
 
-def _threefry2x32_lowering(prng, platform, keys, data):
+def _threefry2x32_lowering(prng, platform, keys, data,
+                           length: Optional[Union[int, ir.Value]] = None):
   """ThreeFry2x32 kernel for GPU."""
   assert len(keys) == 2, keys
   assert len(data) == 2, data
   assert (ir.RankedTensorType(keys[0].type).element_type ==
           ir.IntegerType.get_unsigned(32)), keys[0].type
 
   typ = keys[0].type
   dims = ir.RankedTensorType(typ).shape
-  if any(d < 0 for d in dims):
-    raise NotImplementedError("Shape polymorphism for custom call is not implemented (threefry); b/261671778")
-
   for x in itertools.chain(keys, data):
     assert x.type == typ, (x.type, typ)
   ndims = len(dims)
 
-  opaque = prng.threefry2x32_descriptor(_prod(dims))
   layout = tuple(range(ndims - 1, -1, -1))
+  operand_layouts = [layout] * 4
+  operands = [keys[0], keys[1], data[0], data[1]]
+
+  if length is None:
+    length = _prod(dims)
+
+  if isinstance(length, int):
+    opaque = prng.threefry2x32_descriptor(length)
+  else:
+    opaque = prng.threefry2x32_descriptor(-1)
+    assert (ir.RankedTensorType(length.type).element_type ==
+            ir.IntegerType.get_signless(64)), length
+    assert (ir.RankedTensorType(length.type).shape ==
+            [1]), (length, ir.RankedTensorType(length.type).shape)
+    operands.append(length)
+    operand_layouts.append((0,))
+
   return custom_call(
       f"{platform}_threefry2x32",
       [typ, typ],
-      [keys[0], keys[1], data[0], data[1]],
+      operands,
       backend_config=opaque,
-      operand_layouts=[layout] * 4,
+      operand_layouts=operand_layouts,
       result_layouts=[layout] * 2)
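For illustration, a caller-side sketch of the two ways the new length parameter of _threefry2x32_lowering can be supplied. This is hypothetical usage, not part of the commit: gpu_prng, keys, data, length_value, and the "cu" platform prefix are stand-ins; only the function signature comes from the diff above.

# Static shape: omit length; the lowering computes prod(dims) itself and
# bakes it into the descriptor.
outs = _threefry2x32_lowering(gpu_prng, "cu", keys, data)

# Dynamic shape: pass length as an i64[1] ir.Value (e.g. built with
# mlir.eval_dynamic_shape, as in _threefry2x32_gpu_lowering above); the
# descriptor then carries n = -1 and the value becomes a 5th operand.
outs = _threefry2x32_lowering(gpu_prng, "cu", keys, data, length=length_value)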