mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with: `jax.config.update('jax_threefry_gpu_kernel_lowering', True)` PiperOrigin-RevId: 629763763
This commit is contained in:
parent
9bf1148e74
commit
3f9540761e
@ -33,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
be created and threaded in and out of computations to build up dependency.
|
||||
The singleton object `core.token` has been removed, users now should create
|
||||
and use fresh `core.Token` objects instead.
|
||||
* On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
|
||||
by default. This choice can improve runtime memory usage at a compile-time
|
||||
cost. Prior behavior, which produces a kernel call, can be recovered with
|
||||
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
|
||||
default causes issues, please file a bug. Otherwise, we intend to remove
|
||||
this flag in a future release.
|
||||
|
||||
* Deprecations & Removals
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
|
@ -210,6 +210,7 @@ def trace_context():
|
||||
dynamic_shapes.value, numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
@ -811,6 +812,7 @@ class _GlobalExtraJitContext(NamedTuple):
|
||||
dynamic_shapes: bool = False
|
||||
random_seed_offset: int = 0
|
||||
threefry_partitionable: bool = False
|
||||
threefry_gpu_kernel_lowering: bool = False
|
||||
softmax_custom_jvp: bool = False
|
||||
xla_profile_version: int = 0
|
||||
|
||||
@ -845,6 +847,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
dynamic_shapes: bool | None = None
|
||||
random_seed_offset: int | None = None
|
||||
threefry_partitionable: bool | None = None
|
||||
threefry_gpu_kernel_lowering: bool | None = None
|
||||
softmax_custom_jvp: bool | None = None
|
||||
xla_profile_version: int | None = None
|
||||
|
||||
@ -1083,6 +1086,17 @@ threefry_partitionable = define_bool_state(
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
threefry_partitionable=val))
|
||||
|
||||
threefry_gpu_kernel_lowering = define_bool_state(
|
||||
name='jax_threefry_gpu_kernel_lowering',
|
||||
default=False,
|
||||
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
|
||||
'This makes compile times faster at a potential runtime memory '
|
||||
'cost.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
threefry_gpu_kernel_lowering=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
threefry_gpu_kernel_lowering=val))
|
||||
|
||||
|
||||
softmax_custom_jvp = define_bool_state(
|
||||
name='jax_softmax_custom_jvp',
|
||||
|
@ -47,8 +47,9 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import gpu_prng
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.array_methods import (
|
||||
_array_operators, _set_array_base_attributes, _IndexUpdateHelper)
|
||||
@ -1002,17 +1003,63 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
|
||||
return tuple(x)
|
||||
|
||||
|
||||
_threefry2x32_lowering_rule = mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=False),
|
||||
multiple_results=True)
|
||||
|
||||
_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=True),
|
||||
multiple_results=True)
|
||||
|
||||
|
||||
def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
|
||||
if not config.threefry_gpu_kernel_lowering.value: # back to default lowering
|
||||
return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
|
||||
|
||||
aval_out, aval_out_2 = ctx.avals_out
|
||||
assert aval_out == aval_out_2
|
||||
k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
|
||||
rank = len(aval_out.shape)
|
||||
if 0 in aval_out.shape:
|
||||
zeros = mlir.full_like_aval(ctx, 0, aval_out)
|
||||
return [zeros, zeros]
|
||||
def _broadcast(x, aval):
|
||||
return mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=range(rank - len(aval.shape), rank))
|
||||
|
||||
out_len = reduce(op.mul, aval_out.shape, 1)
|
||||
if not core.is_constant_dim(out_len):
|
||||
length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
|
||||
length = mlir.hlo.convert(
|
||||
ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
|
||||
length)
|
||||
output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
|
||||
else:
|
||||
length = int(out_len) # will be passed statically
|
||||
output_shape = None
|
||||
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape)
|
||||
|
||||
threefry2x32_p = core.Primitive("threefry2x32")
|
||||
threefry2x32_p.multiple_results = True
|
||||
threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
|
||||
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
|
||||
batching.defbroadcasting(threefry2x32_p)
|
||||
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=False),
|
||||
multiple_results=True))
|
||||
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=True),
|
||||
multiple_results=True), platform='cpu')
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p, _threefry2x32_lowering_rule)
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p,
|
||||
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
threefry2x32_p,
|
||||
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
def iota_2x32_shape(shape):
|
||||
|
@ -793,7 +793,7 @@ def _check_lowering(lowering) -> None:
|
||||
# Their backwards compatibility is tested by back_compat_test.py.
|
||||
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
||||
"dynamic_ducc_fft",
|
||||
"dynamic_ducc_fft", "cu_threefry2x32",
|
||||
# cholesky on CPU
|
||||
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
|
||||
# eigh on CPU
|
||||
|
@ -65,7 +65,8 @@ config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key='allow',
|
||||
jax_debug_key_reuse=False)
|
||||
jax_debug_key_reuse=False,
|
||||
jax_threefry_gpu_kernel_lowering=True)
|
||||
class CompatTest(bctu.CompatTestBase):
|
||||
def test_dummy(self):
|
||||
# Tests the testing mechanism. Let this test run on all platforms
|
||||
@ -573,12 +574,11 @@ class CompatTest(bctu.CompatTestBase):
|
||||
self.run_one_test(func, data)
|
||||
|
||||
def test_cuda_threefry2x32(self):
|
||||
# TODO(frostig): remove after 2024-11-01
|
||||
def func(x):
|
||||
return jax.random.uniform(x, (2, 4), dtype=np.float32)
|
||||
|
||||
data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
|
||||
self.run_one_test(func, data, expect_current_custom_calls=[])
|
||||
self.run_one_test(func, data)
|
||||
|
||||
def test_sharding(self):
|
||||
# Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
|
||||
|
@ -385,6 +385,16 @@ class PrngTest(jtu.JaxTestCase):
|
||||
random.key_data(random.fold_in(make_key(seed), 4)),
|
||||
np.array([2285895361, 433833334], dtype='uint32'))
|
||||
|
||||
@jtu.run_on_devices("gpu")
|
||||
def test_threefry_gpu_kernel_lowering(self):
|
||||
f = lambda key: jax.random.uniform(key, (1,))
|
||||
with jax._src.config.threefry_gpu_kernel_lowering(False):
|
||||
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
|
||||
self.assertNotIn("cu_threefry2x32", hlo_text)
|
||||
with jax._src.config.threefry_gpu_kernel_lowering(True):
|
||||
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
|
||||
self.assertIn("cu_threefry2x32", hlo_text)
|
||||
|
||||
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
|
||||
def test_random_seed_offset(self, make_key):
|
||||
k1 = make_key(17)
|
||||
|
Loading…
x
Reference in New Issue
Block a user