diff --git a/CHANGELOG.md b/CHANGELOG.md
index 785d3da3d..197e81797 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
     be created and threaded in and out of computations to build up dependency.
     The singleton object `core.token` has been removed, users now should create
     and use fresh `core.Token` objects instead.
+  * On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
+    by default. This choice can improve runtime memory usage at a compile-time
+    cost. Prior behavior, which produces a kernel call, can be recovered with
+    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
+    default causes issues, please file a bug. Otherwise, we intend to remove
+    this flag in a future release.

 * Deprecations & Removals
   * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 2282b3f10..8ae6dab0b 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -210,6 +210,7 @@ def trace_context():
           dynamic_shapes.value, numpy_dtype_promotion.value,
           default_device.value, random_seed_offset.value,
           threefry_partitionable.value,
+          threefry_gpu_kernel_lowering.value,
           softmax_custom_jvp.value,
           enable_memories.value,
           disable_jit.value,
@@ -811,6 +812,7 @@ class _GlobalExtraJitContext(NamedTuple):
   dynamic_shapes: bool = False
   random_seed_offset: int = 0
   threefry_partitionable: bool = False
+  threefry_gpu_kernel_lowering: bool = False
   softmax_custom_jvp: bool = False
   xla_profile_version: int = 0

@@ -845,6 +847,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   dynamic_shapes: bool | None = None
   random_seed_offset: int | None = None
   threefry_partitionable: bool | None = None
+  threefry_gpu_kernel_lowering: bool | None = None
   softmax_custom_jvp: bool | None = None
   xla_profile_version: int | None = None

@@ -1083,6 +1086,17 @@ threefry_partitionable = define_bool_state(
     update_thread_local_hook=lambda val: update_thread_local_jit_state(
         threefry_partitionable=val))

+threefry_gpu_kernel_lowering = define_bool_state(
+    name='jax_threefry_gpu_kernel_lowering',
+    default=False,
+    help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
+          'This makes compile times faster at a potential runtime memory '
+          'cost.'),
+    update_global_hook=lambda val: _update_global_jit_state(
+        threefry_gpu_kernel_lowering=val),
+    update_thread_local_hook=lambda val: update_thread_local_jit_state(
+        threefry_gpu_kernel_lowering=val))
+
 softmax_custom_jvp = define_bool_state(
     name='jax_softmax_custom_jvp',
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 2d3a76eb0..1d26342c9 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -47,8 +47,9 @@ from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import utils as lax_utils
-from jax._src.lib.mlir import ir
+from jax._src.lib import gpu_prng
 from jax._src.lib import xla_client as xc
+from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.array_methods import (
     _array_operators, _set_array_base_attributes, _IndexUpdateHelper)
@@ -1002,17 +1003,63 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)


+_threefry2x32_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=False),
+    multiple_results=True)
+
+_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=True),
+    multiple_results=True)
+
+def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
+  if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
+    return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
+
+  aval_out, aval_out_2 = ctx.avals_out
+  assert aval_out == aval_out_2
+  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
+  rank = len(aval_out.shape)
+  if 0 in aval_out.shape:
+    zeros = mlir.full_like_aval(ctx, 0, aval_out)
+    return [zeros, zeros]
+  def _broadcast(x, aval):
+    return mlir.broadcast_in_dim(ctx, x, aval_out,
+                                 broadcast_dimensions=range(rank - len(aval.shape), rank))
+
+  out_len = reduce(op.mul, aval_out.shape, 1)
+  if not core.is_constant_dim(out_len):
+    length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
+    length = mlir.hlo.convert(
+        ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
+        length)
+    output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
+  else:
+    length = int(out_len)  # will be passed statically
+    output_shape = None
+
+  return lowering_func(
+      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+      output_shape)
+
+
 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
 threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
 threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
 batching.defbroadcasting(threefry2x32_p)
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=False),
-    multiple_results=True))
-mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
-    partial(_threefry2x32_lowering, use_rolled_loops=True),
-    multiple_results=True), platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_lowering_rule)
+mlir.register_lowering(
+    threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
+    platform='cuda')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
+    platform='rocm')


 def iota_2x32_shape(shape):
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index 441da1a3c..733412772 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -793,7 +793,7 @@ def _check_lowering(lowering) -> None:
   # Their backwards compatibility is tested by back_compat_test.py.
   _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
     "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
-    "dynamic_ducc_fft",
+    "dynamic_ducc_fft", "cu_threefry2x32",
     # cholesky on CPU
     "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
     # eigh on CPU
diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py
index c9aff9df4..2ac808c34 100644
--- a/tests/export_back_compat_test.py
+++ b/tests/export_back_compat_test.py
@@ -65,7 +65,8 @@ config.parse_flags_with_absl()


 @jtu.with_config(jax_legacy_prng_key='allow',
-                 jax_debug_key_reuse=False)
+                 jax_debug_key_reuse=False,
+                 jax_threefry_gpu_kernel_lowering=True)
 class CompatTest(bctu.CompatTestBase):
   def test_dummy(self):
     # Tests the testing mechanism. Let this test run on all platforms
@@ -573,12 +574,11 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data)

   def test_cuda_threefry2x32(self):
-    # TODO(frostig): remove after 2024-11-01
     def func(x):
       return jax.random.uniform(x, (2, 4), dtype=np.float32)

     data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
-    self.run_one_test(func, data, expect_current_custom_calls=[])
+    self.run_one_test(func, data)

   def test_sharding(self):
     # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
diff --git a/tests/random_test.py b/tests/random_test.py
index 02c731c11..08e039080 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -385,6 +385,16 @@ class PrngTest(jtu.JaxTestCase):
         random.key_data(random.fold_in(make_key(seed), 4)),
         np.array([2285895361, 433833334], dtype='uint32'))

+  @jtu.run_on_devices("gpu")
+  def test_threefry_gpu_kernel_lowering(self):
+    f = lambda key: jax.random.uniform(key, (1,))
+    with jax._src.config.threefry_gpu_kernel_lowering(False):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+      self.assertNotIn("cu_threefry2x32", hlo_text)
+    with jax._src.config.threefry_gpu_kernel_lowering(True):
+      hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
+      self.assertIn("cu_threefry2x32", hlo_text)
+
   @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
   def test_random_seed_offset(self, make_key):
     k1 = make_key(17)
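
A minimal usage sketch of the flag introduced by this patch, not part of the diff itself. It assumes a CUDA GPU backend, where the kernel lowering appears as the "cu_threefry2x32" custom call checked by tests/random_test.py above; the function name `draw` is illustrative.

import jax
import numpy as np

def draw(key):
  # Any threefry-backed random op works here; uniform mirrors the tests above.
  return jax.random.uniform(key, (2, 4), dtype=np.float32)

key = jax.random.key(0)

# New default: threefry lowers inline, so no kernel custom call is emitted.
print("cu_threefry2x32" in jax.jit(draw).lower(key).as_text())  # expect: False

# Opt back into the kernel lowering (this flag is slated for eventual removal).
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
print("cu_threefry2x32" in jax.jit(draw).lower(key).as_text())  # expect: True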