Add support for dynamic shapes to the GPU threefry2x32 custom call.

In the presence of dynamic shapes the ThreeFry2x32Descriptor contains the
value n=-1, and the actual output length is passed as an additional
operand. If the shape is static, the length is passed as part of the
descriptor.

PiperOrigin-RevId: 497945778
George Necula 2022-12-27 04:47:48 -08:00 committed by jax authors
parent 0c8a4fb7cd
commit 7d452adfd3
5 changed files with 68 additions and 22 deletions
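Before the per-file diffs, a minimal self-contained sketch of the dispatch this commit introduces. This is illustrative only: pack_threefry2x32_call is a hypothetical helper, and the real logic lives in jaxlib's _threefry2x32_lowering, shown in the last hunk below.

from typing import Any, List, Tuple, Union

def pack_threefry2x32_call(operands: List[Any],
                           length: Union[int, Any]) -> Tuple[int, List[Any]]:
  """Returns (descriptor_n, operands) for the threefry2x32 custom call."""
  if isinstance(length, int):
    # Static shape: bake the output length into the descriptor.
    return length, operands
  # Dynamic shape: the descriptor carries n=-1 and the runtime length is
  # appended as a 1-element int64 operand.
  return -1, operands + [length]

# Static: length goes in the descriptor; dynamic: length rides as a 5th operand.
assert pack_threefry2x32_call(["k0", "k1", "x0", "x1"], 24) == (24, ["k0", "k1", "x0", "x1"])
assert pack_threefry2x32_call(["k0", "k1", "x0", "x1"], "len_i64")[0] == -1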

View File

@@ -44,7 +44,7 @@ from jax._src.lax import utils as lax_utils
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy import lax_numpy
 from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
-from jax._src.lib import gpu_prng
+from jax._src.lib import gpu_prng, xla_extension_version
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
@@ -940,7 +940,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)
-def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
+def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
   aval_out, _ = ctx.avals_out
   k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
   rank = len(aval_out.shape)
@@ -950,10 +950,24 @@ def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
   def _broadcast(x, aval):
     return mlir.broadcast_in_dim(ctx, x, aval_out,
                                  broadcast_dimensions=range(rank - len(aval.shape), rank))
-  return threefry2x32_lowering(
-          (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-          (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
+  if xla_extension_version >= 113:
+    out_len = reduce(op.mul, aval_out.shape, 1)
+    if not core.is_constant_dim(out_len):
+      length = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, [out_len]))
+      length = mlir.hlo.ConvertOp(
+          mlir.ir.RankedTensorType.get((1,), mlir.ir.IntegerType.get_signless(64)),
+          length).result
+    else:
+      length = int(out_len)  # will be passed statically
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length)
+  else:
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
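The branch above hinges on whether the product of the output shape is a compile-time constant (core.is_constant_dim). A toy illustration, using a stand-in SymDim class rather than JAX's actual symbolic dimensions, of why that product stops being a plain Python int once any axis is polymorphic:

from functools import reduce
import operator as op

class SymDim:
  """Toy stand-in for a shape-polymorphic dimension such as `b`."""
  def __init__(self, name):
    self.name = name
  def __mul__(self, other):
    return SymDim(f"{self.name}*{other}")
  __rmul__ = __mul__
  def __repr__(self):
    return self.name

def is_constant_dim(d) -> bool:  # rough analogue of core.is_constant_dim
  return isinstance(d, int)

static_len = reduce(op.mul, (3, 8), 1)          # 24, a plain int
poly_len = reduce(op.mul, (SymDim("b"), 8), 1)  # a symbolic expression
print(is_constant_dim(static_len))  # True  -> length baked into the descriptor
print(is_constant_dim(poly_len))    # False -> length passed as an i64 operand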

View File

@@ -1873,8 +1873,16 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 poly_axes=[None, 0]),
    PolyHarness("random_categorical", "axis=1",
                lambda key, a: jax.random.categorical(key, a, axis=1),
-               arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 8), _f32)],
-               poly_axes=[None, 0]),
+               arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 8), _f32)],
+               poly_axes=[None, (0, 1)]),
+   PolyHarness("random_categorical", "axis=1_then_reshape",
+               lambda key, a: jax.random.categorical(key, a, axis=1).reshape((-1)),
+               arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 8), _f32)],
+               poly_axes=[None, (0, 1)]),
+   PolyHarness("random_categorical", "0_dim",  # One axis has 0 size
+               lambda key, a: jax.random.categorical(key, a, axis=1),
+               arg_descriptors=[RandArg((2,), np.uint32), RandArg((3, 5, 0), _f32)],
+               poly_axes=[None, (0, 1)]),
    # Works when the known dimensions are known to be even or odd.
    PolyHarness("random_uniform", "even_1",
                lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
@@ -2262,9 +2270,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
        "householder_product:cpu", "householder_product:gpu",
        "vmap_geqrf:cpu", "vmap_geqrf:gpu",
        "vmap_lu:cpu", "vmap_lu:gpu", "vmap_qr:cpu", "vmap_qr:gpu",
-       "random_gamma:gpu", "vmap_random_gamma:gpu",
-       "random_categorical:gpu", "vmap_random_categorical:gpu",
-       "random_randint:gpu", "random_uniform:gpu", "vmap_random_split:gpu",
        "vmap_svd:cpu", "vmap_svd:gpu"}
    if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
      raise unittest.SkipTest("native lowering with shape polymorphism not implemented for custom calls; b/261671778")

View File

@@ -107,15 +107,26 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
   std::array<const std::uint32_t*, 2> data;
   data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
   data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
+  std::int64_t n = descriptor.n;
+  int output_idx = 4;
+  if (n < 0) {
+    // n is an operand in device memory.
+    gpuMemcpyAsync((void*)&n, reinterpret_cast<const std::int64_t*>(buffers[4]),
+                   sizeof(n), gpuMemcpyDeviceToHost,
+                   stream);
+    gpuStreamSynchronize(stream);
+    output_idx = 5;
+  }
   std::array<std::uint32_t*, 2> out;
-  out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
-  out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
+  out[0] = reinterpret_cast<std::uint32_t*>(buffers[output_idx]);
+  out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
   const int block_dim = 128;
   const std::int64_t grid_dim =
-      std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
+      std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
   ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
                        stream>>>(keys[0], keys[1], data[0], data[1], out[0],
-                                 out[1], descriptor.n);
+                                 out[1], n);
 }
 } // namespace JAX_GPU_NAMESPACE
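For orientation, an illustrative (non-jaxlib) sketch of the buffer layout the launcher above expects. The keys presumably occupy slots 0 and 1 (their assignment sits just above the shown hunk), and a negative descriptor value shifts the outputs by one slot to make room for the runtime length:

def threefry2x32_buffer_layout(descriptor_n: int) -> dict:
  """Illustrative mapping of custom-call buffer slots for the GPU kernel."""
  buffers = {0: "key[0]", 1: "key[1]", 2: "data[0]", 3: "data[1]"}
  if descriptor_n < 0:
    # Dynamic case: slot 4 holds the 1-element int64 length, which the
    # launcher copies device-to-host before computing the grid size.
    buffers[4] = "length"
    out = 5
  else:
    out = 4
  buffers[out] = "out[0]"
  buffers[out + 1] = "out[1]"
  return buffers

print(threefry2x32_buffer_layout(-1))  # length at slot 4, outputs at 5 and 6
print(threefry2x32_buffer_layout(24))  # outputs at slots 4 and 5

The device-to-host copy followed by a stream synchronization is needed because the grid dimensions must be known on the host before the kernel can be launched.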

View File

@@ -26,7 +26,7 @@ namespace jax {
 namespace JAX_GPU_NAMESPACE {
 struct ThreeFry2x32Descriptor {
-  std::int64_t n;
+  std::int64_t n;  // If -1 then the length is passed as a 5th operand
 };
 void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,

View File

@@ -17,6 +17,7 @@ import functools
 from functools import partial
 import itertools
 import operator
+from typing import Optional, Union
 import jaxlib.mlir.ir as ir
@@ -41,29 +42,44 @@ except ImportError:
 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)
-def _threefry2x32_lowering(prng, platform, keys, data):
+def _threefry2x32_lowering(prng, platform, keys, data,
+                           length: Optional[Union[int, ir.Value]] = None):
   """ThreeFry2x32 kernel for GPU."""
   assert len(keys) == 2, keys
   assert len(data) == 2, data
   assert (ir.RankedTensorType(keys[0].type).element_type ==
           ir.IntegerType.get_unsigned(32)), keys[0].type
   typ = keys[0].type
   dims = ir.RankedTensorType(typ).shape
-  if any(d < 0 for d in dims):
-    raise NotImplementedError("Shape polymorphism for custom call is not implemented (threefry); b/261671778")
   for x in itertools.chain(keys, data):
     assert x.type == typ, (x.type, typ)
   ndims = len(dims)
-  opaque = prng.threefry2x32_descriptor(_prod(dims))
   layout = tuple(range(ndims - 1, -1, -1))
+  operand_layouts = [layout] * 4
+  operands = [keys[0], keys[1], data[0], data[1]]
+  if length is None:
+    length = _prod(dims)
+  if isinstance(length, int):
+    opaque = prng.threefry2x32_descriptor(length)
+  else:
+    opaque = prng.threefry2x32_descriptor(-1)
+    assert (ir.RankedTensorType(length.type).element_type ==
+            ir.IntegerType.get_signless(64)), length
+    assert (ir.RankedTensorType(length.type).shape ==
+            [1]), (length, ir.RankedTensorType(length.type).shape)
+    operands.append(length)
+    operand_layouts.append((0,))
   return custom_call(
       f"{platform}_threefry2x32",
       [typ, typ],
-      [keys[0], keys[1], data[0], data[1]],
+      operands,
       backend_config=opaque,
-      operand_layouts=[layout] * 4,
+      operand_layouts=operand_layouts,
       result_layouts=[layout] * 2)