reintroduce the Threefry GPU kernel lowering, under a flag

On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. The new default can improve runtime memory usage at some compile-time cost. The prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
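
For example (a minimal sketch, not part of the commit; the function name `sample` and the shapes are illustrative, and the exact memory/compile-time tradeoff depends on the program and device):

    import jax

    # Opt back into the previous behavior: lower Threefry to a GPU kernel call.
    jax.config.update('jax_threefry_gpu_kernel_lowering', True)

    @jax.jit
    def sample(key):
      return jax.random.uniform(key, (2, 4))

    sample(jax.random.key(17))
    # On a CUDA device, the lowered module now contains the cu_threefry2x32
    # custom call, as exercised by the new PrngTest case added below.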

PiperOrigin-RevId: 629763763
Author: Roy Frostig, 2024-05-01 10:32:36 -07:00 (committed by jax authors)
Commit: 3f9540761e (parent: 9bf1148e74)
6 changed files with 88 additions and 11 deletions


@@ -33,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
be created and threaded in and out of computations to build up dependency.
The singleton object `core.token` has been removed, users now should create
and use fresh `core.Token` objects instead.
+* On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
+  by default. This choice can improve runtime memory usage at a compile-time
+  cost. Prior behavior, which produces a kernel call, can be recovered with
+  `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
+  default causes issues, please file a bug. Otherwise, we intend to remove
+  this flag in a future release.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old


@@ -210,6 +210,7 @@ def trace_context():
dynamic_shapes.value, numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
+threefry_gpu_kernel_lowering.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
@@ -811,6 +812,7 @@ class _GlobalExtraJitContext(NamedTuple):
dynamic_shapes: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
+threefry_gpu_kernel_lowering: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
@@ -845,6 +847,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
dynamic_shapes: bool | None = None
random_seed_offset: int | None = None
threefry_partitionable: bool | None = None
+threefry_gpu_kernel_lowering: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None
@@ -1083,6 +1086,17 @@ threefry_partitionable = define_bool_state(
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_partitionable=val))
+threefry_gpu_kernel_lowering = define_bool_state(
+    name='jax_threefry_gpu_kernel_lowering',
+    default=False,
+    help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
+          'This makes compile times faster at a potential runtime memory '
+          'cost.'),
+    update_global_hook=lambda val: _update_global_jit_state(
+        threefry_gpu_kernel_lowering=val),
+    update_thread_local_hook=lambda val: update_thread_local_jit_state(
+        threefry_gpu_kernel_lowering=val))
softmax_custom_jvp = define_bool_state(
name='jax_softmax_custom_jvp',
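
The new flag value is threaded through `trace_context` and the global/thread-local jit contexts above, so toggling it participates in jit's tracing and caching context. As a minimal sketch of how the rest of this commit consumes the flag (reading `.value` in the lowering rule, and using the state object as a context manager in the tests):

    from jax._src import config

    # Read the current setting, as the GPU lowering rule does:
    use_kernel = config.threefry_gpu_kernel_lowering.value

    # Temporarily override it within a scope, as the new test does:
    with config.threefry_gpu_kernel_lowering(True):
      ...  # anything traced/lowered here sees the flag as True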


@@ -47,8 +47,9 @@ from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
-from jax._src.lib.mlir import ir
+from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
+from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
_array_operators, _set_array_base_attributes, _IndexUpdateHelper)
@@ -1002,17 +1003,63 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
return tuple(x)
+_threefry2x32_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=False),
+    multiple_results=True)
+
+_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=True),
+    multiple_results=True)
+
+def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
+  if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
+    return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
+
+  aval_out, aval_out_2 = ctx.avals_out
+  assert aval_out == aval_out_2
+  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
+  rank = len(aval_out.shape)
+  if 0 in aval_out.shape:
+    zeros = mlir.full_like_aval(ctx, 0, aval_out)
+    return [zeros, zeros]
+
+  def _broadcast(x, aval):
+    return mlir.broadcast_in_dim(ctx, x, aval_out,
+        broadcast_dimensions=range(rank - len(aval.shape), rank))
+
+  out_len = reduce(op.mul, aval_out.shape, 1)
+  if not core.is_constant_dim(out_len):
+    length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
+    length = mlir.hlo.convert(
+        ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
+        length)
+    output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
+  else:
+    length = int(out_len)  # will be passed statically
+    output_shape = None
+
+  return lowering_func(
+      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+      output_shape)
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=False),
-    multiple_results=True))
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=True),
-    multiple_results=True), platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_lowering_rule)
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
+    platform='cuda')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
+    platform='rocm')
def iota_2x32_shape(shape):
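
With the registrations above, the kernel path is only reachable on the 'cuda' and 'rocm' platforms, and only when the flag is enabled; otherwise `_threefry2x32_gpu_lowering_rule` falls back to the generic lowering. A quick way to see which path was taken is to look for the custom-call target in the lowered module, which is essentially what the new test below does (a sketch; assumes a CUDA GPU is the default backend):

    import jax
    from jax._src import config

    f = lambda key: jax.random.uniform(key, (1,))

    with config.threefry_gpu_kernel_lowering(True):
      text = jax.jit(f).lower(jax.random.key(17)).as_text()
    print("cu_threefry2x32" in text)   # True on CUDA with the flag on

    with config.threefry_gpu_kernel_lowering(False):
      text = jax.jit(f).lower(jax.random.key(17)).as_text()
    print("cu_threefry2x32" in text)   # False: the generic lowering is used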


@@ -793,7 +793,7 @@ def _check_lowering(lowering) -> None:
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"dynamic_ducc_fft",
"dynamic_ducc_fft", "cu_threefry2x32",
# cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU


@@ -65,7 +65,8 @@ config.parse_flags_with_absl()
@jtu.with_config(jax_legacy_prng_key='allow',
-                 jax_debug_key_reuse=False)
+                 jax_debug_key_reuse=False,
+                 jax_threefry_gpu_kernel_lowering=True)
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
# Tests the testing mechanism. Let this test run on all platforms
@@ -573,12 +574,11 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data)
def test_cuda_threefry2x32(self):
-    # TODO(frostig): remove after 2024-11-01
def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32)
data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
-    self.run_one_test(func, data, expect_current_custom_calls=[])
+    self.run_one_test(func, data)
def test_sharding(self):
# Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU


@@ -385,6 +385,16 @@ class PrngTest(jtu.JaxTestCase):
random.key_data(random.fold_in(make_key(seed), 4)),
np.array([2285895361, 433833334], dtype='uint32'))
+  @jtu.run_on_devices("gpu")
+  def test_threefry_gpu_kernel_lowering(self):
+    f = lambda key: jax.random.uniform(key, (1,))
+    with jax._src.config.threefry_gpu_kernel_lowering(False):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+    self.assertNotIn("cu_threefry2x32", hlo_text)
+    with jax._src.config.threefry_gpu_kernel_lowering(True):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+    self.assertIn("cu_threefry2x32", hlo_text)
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_random_seed_offset(self, make_key):
k1 = make_key(17)