Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`.

PiperOrigin-RevId: 629763763
This commit is contained in:
parent 9bf1148e74
commit 3f9540761e
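For context, a minimal usage sketch (assuming a CUDA-capable jaxlib is installed). The `jax.config.update` call is the one quoted in the commit message, and `cu_threefry2x32` is the custom-call target registered further down in this diff:

```python
import jax

# New default after this change: Threefry lowers to pure XLA ops on GPU.
# Opt back in to the previous kernel-call lowering:
jax.config.update('jax_threefry_gpu_kernel_lowering', True)

key = jax.random.key(17)
x = jax.random.uniform(key, (2, 4))  # uses the cu_threefry2x32 kernel on CUDA
```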
@@ -33,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
     be created and threaded in and out of computations to build up dependency.
     The singleton object `core.token` has been removed, users now should create
     and use fresh `core.Token` objects instead.
+  * On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
+    by default. This choice can improve runtime memory usage at a compile-time
+    cost. Prior behavior, which produces a kernel call, can be recovered with
+    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
+    default causes issues, please file a bug. Otherwise, we intend to remove
+    this flag in a future release.
 
 * Deprecations & Removals
   * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
@@ -210,6 +210,7 @@ def trace_context():
           dynamic_shapes.value, numpy_dtype_promotion.value,
           default_device.value, random_seed_offset.value,
           threefry_partitionable.value,
+          threefry_gpu_kernel_lowering.value,
           softmax_custom_jvp.value,
           enable_memories.value,
           disable_jit.value,
@@ -811,6 +812,7 @@ class _GlobalExtraJitContext(NamedTuple):
   dynamic_shapes: bool = False
   random_seed_offset: int = 0
   threefry_partitionable: bool = False
+  threefry_gpu_kernel_lowering: bool = False
   softmax_custom_jvp: bool = False
   xla_profile_version: int = 0
 
@@ -845,6 +847,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   dynamic_shapes: bool | None = None
   random_seed_offset: int | None = None
   threefry_partitionable: bool | None = None
+  threefry_gpu_kernel_lowering: bool | None = None
   softmax_custom_jvp: bool | None = None
   xla_profile_version: int | None = None
 
@@ -1083,6 +1086,17 @@ threefry_partitionable = define_bool_state(
     update_thread_local_hook=lambda val: update_thread_local_jit_state(
         threefry_partitionable=val))
 
+threefry_gpu_kernel_lowering = define_bool_state(
+    name='jax_threefry_gpu_kernel_lowering',
+    default=False,
+    help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
+          'This makes compile times faster at a potential runtime memory '
+          'cost.'),
+    update_global_hook=lambda val: _update_global_jit_state(
+        threefry_gpu_kernel_lowering=val),
+    update_thread_local_hook=lambda val: update_thread_local_jit_state(
+        threefry_gpu_kernel_lowering=val))
+
 
 softmax_custom_jvp = define_bool_state(
     name='jax_softmax_custom_jvp',
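Because the flag is a `define_bool_state` whose value is threaded through `trace_context()` and both jit contexts, it can also be overridden for a limited scope. The context-manager form below mirrors the test added at the end of this diff; note that `jax._src.config` is internal, so the public `jax.config.update` call shown earlier remains the supported entry point:

```python
import jax
from jax._src import config  # internal module; shown only to mirror the test below

def draw(key):
  return jax.random.uniform(key, (1,))

# Lowerings produced inside the `with` block use the kernel path; since the
# flag participates in the trace context, jit caches the two variants separately.
with config.threefry_gpu_kernel_lowering(True):
  kernel_hlo = jax.jit(draw).lower(jax.random.key(0)).as_text()

default_hlo = jax.jit(draw).lower(jax.random.key(0)).as_text()
```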
@@ -47,8 +47,9 @@ from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import utils as lax_utils
-from jax._src.lib.mlir import ir
+from jax._src.lib import gpu_prng
 from jax._src.lib import xla_client as xc
+from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.array_methods import (
     _array_operators, _set_array_base_attributes, _IndexUpdateHelper)
@@ -1002,17 +1003,63 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)
 
 
+_threefry2x32_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=False),
+    multiple_results=True)
+
+_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=True),
+    multiple_results=True)
+
+
+def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
+  if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
+    return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
+
+  aval_out, aval_out_2 = ctx.avals_out
+  assert aval_out == aval_out_2
+  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
+  rank = len(aval_out.shape)
+  if 0 in aval_out.shape:
+    zeros = mlir.full_like_aval(ctx, 0, aval_out)
+    return [zeros, zeros]
+  def _broadcast(x, aval):
+    return mlir.broadcast_in_dim(ctx, x, aval_out,
+        broadcast_dimensions=range(rank - len(aval.shape), rank))
+
+  out_len = reduce(op.mul, aval_out.shape, 1)
+  if not core.is_constant_dim(out_len):
+    length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
+    length = mlir.hlo.convert(
+        ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
+        length)
+    output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
+  else:
+    length = int(out_len)  # will be passed statically
+    output_shape = None
+
+  return lowering_func(
+      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+      output_shape)
+
+
 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
 threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
 threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
 batching.defbroadcasting(threefry2x32_p)
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=False),
-    multiple_results=True))
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=True),
-    multiple_results=True), platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_lowering_rule)
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
+    platform='cuda')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
+    platform='rocm')
 
 
 def iota_2x32_shape(shape):
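A quick way to confirm which of these registered lowerings a program actually takes is to look for the `cu_threefry2x32` custom-call target in the lowered module text (the ROCm registration goes through `gpu_prng.rocm_threefry2x32` analogously). A sketch on a CUDA machine, essentially what the new PrngTest case further down asserts:

```python
import jax

def sample(key):
  return jax.random.uniform(key, (1,))

hlo = jax.jit(sample).lower(jax.random.key(17)).as_text()
# False under the new default; True once jax_threefry_gpu_kernel_lowering is enabled.
print("cu_threefry2x32" in hlo)
```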
@@ -793,7 +793,7 @@ def _check_lowering(lowering) -> None:
   # Their backwards compatibility is tested by back_compat_test.py.
   _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
       "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
-      "dynamic_ducc_fft",
+      "dynamic_ducc_fft", "cu_threefry2x32",
       # cholesky on CPU
       "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
       # eigh on CPU
@@ -65,7 +65,8 @@ config.parse_flags_with_absl()
 
 
 @jtu.with_config(jax_legacy_prng_key='allow',
-                 jax_debug_key_reuse=False)
+                 jax_debug_key_reuse=False,
+                 jax_threefry_gpu_kernel_lowering=True)
 class CompatTest(bctu.CompatTestBase):
   def test_dummy(self):
     # Tests the testing mechanism. Let this test run on all platforms
@@ -573,12 +574,11 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data)
 
   def test_cuda_threefry2x32(self):
-    # TODO(frostig): remove after 2024-11-01
     def func(x):
       return jax.random.uniform(x, (2, 4), dtype=np.float32)
 
     data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
-    self.run_one_test(func, data, expect_current_custom_calls=[])
+    self.run_one_test(func, data)
 
   def test_sharding(self):
     # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
@@ -385,6 +385,16 @@ class PrngTest(jtu.JaxTestCase):
         random.key_data(random.fold_in(make_key(seed), 4)),
         np.array([2285895361, 433833334], dtype='uint32'))
 
+  @jtu.run_on_devices("gpu")
+  def test_threefry_gpu_kernel_lowering(self):
+    f = lambda key: jax.random.uniform(key, (1,))
+    with jax._src.config.threefry_gpu_kernel_lowering(False):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+      self.assertNotIn("cu_threefry2x32", hlo_text)
+    with jax._src.config.threefry_gpu_kernel_lowering(True):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+      self.assertIn("cu_threefry2x32", hlo_text)
+
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def test_random_seed_offset(self, make_key):
     k1 = make_key(17)
|
Loading…
x
Reference in New Issue
Block a user