Remove forward compatibility mode for old PRNG custom call on GPU

The backend support for the new custom call was added on June 28th.
Also add a backwards-compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
commit 65450d165e
parent 618754d829
Author: George Necula, 2024-07-31 08:09:28 -07:00 (committed by jax authors)
6 changed files with 114 additions and 68 deletions
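For context before the diffs: the lowering changed here is gated by a user-facing flag. Below is a minimal sketch of exercising it, assuming a CUDA-enabled jaxlib; the jax_threefry_gpu_kernel_lowering flag name is inferred from the config.threefry_gpu_kernel_lowering option that appears in the diff and is an assumption.

    # Minimal sketch, assuming a CUDA-enabled jaxlib. The flag name is
    # inferred from config.threefry_gpu_kernel_lowering in the diff below.
    import jax
    import numpy as np

    jax.config.update("jax_threefry_gpu_kernel_lowering", True)

    def func(x):
      return jax.random.uniform(x, (2, 4), dtype=np.float32)

    # On GPU this now dispatches to the threefry2x32 custom call.
    print(func(jax.random.key(42)))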

[Diff for one file suppressed because its lines are too long.]

----------------------------------------

@@ -46,7 +46,6 @@ from jax._src.interpreters import xla
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import gpu_prng
 from jax._src.lib import xla_client as xc
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.array_methods import (
@@ -901,16 +900,6 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
   if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
     return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
-  # TODO(b/338022728): when we export, use the old custom call target for now.
-  # Make forward_compatibility_mode False after 3 weeks.
-  # TODO(b/350111820): figure out why we cannot use the new cu_threefry2x32_ffi
-  # in Kokoro tests. For now, use the old cu_threefry2x32.
-  # lowering_parameters = ctx.module_context.lowering_parameters
-  # forward_compatibility_mode = (
-  #     lowering_parameters.for_export and
-  #     not lowering_parameters.export_ignore_forward_compatibility)
-  forward_compatibility_mode = True
   aval_out, aval_out_2 = ctx.avals_out
   assert aval_out == aval_out_2
   k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
@@ -933,17 +922,13 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
     length = int(out_len)  # will be passed statically
     output_shape = None
-  if jaxlib_version >= (0, 4, 31):
-    return lowering_func(
-        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
-        output_shape,
-        forward_compatibility_mode)
-  else:
-    return lowering_func(
-        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
-        output_shape)
+  return lowering_func(
+      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+      output_shape,
+      False,  # forward_compatibility_mode
+  )

 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
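One way to observe the effect of this hunk is to inspect the lowered module for the custom call target name. A sketch, assuming a CUDA backend and the flag name noted earlier; the target name is the one this commit introduces:

    import jax
    import numpy as np

    jax.config.update("jax_threefry_gpu_kernel_lowering", True)

    def func(x):
      return jax.random.uniform(x, (2, 4), dtype=np.float32)

    hlo_text = jax.jit(func).lower(jax.random.key(42)).as_text()
    # After this commit the lowering references only the new FFI target.
    assert "cu_threefry2x32_ffi" in hlo_text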

----------------------------------------

@@ -30,18 +30,13 @@ nb::dict Registrations() {
   nb::dict dict;
   dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
       EncapsulateFfiHandler(ThreeFry2x32Ffi);
-  // TODO(b/338022728): remove after 3 weeks
+  // TODO(b/338022728): remove after 6 months
   dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
   return dict;
 }

 NB_MODULE(_prng, m) {
   m.def("registrations", &Registrations);
-  // TODO(b/338022728): remove after 3 weeks
-  m.def("threefry2x32_descriptor", [](std::int64_t n) {
-    std::string result = BuildThreeFry2x32Descriptor(n);
-    return nb::bytes(result.data(), result.size());
-  });
 }

 }  // namespace
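The dict returned by Registrations() is consumed on the Python side. A hedged sketch of the usual wiring; the exact jax call site is not part of this diff, and register_prng_targets is a hypothetical helper name:

    from jaxlib import xla_client

    def register_prng_targets(registrations, platform="CUDA"):
      # Hypothetical helper: targets with the _ffi suffix use the typed FFI
      # calling convention (api_version=1 in XLA's registration API); the
      # legacy cu_threefry2x32 target stays registered so that modules
      # serialized before this change keep loading.
      for name, fn in registrations.items():
        api_version = 1 if name.endswith("_ffi") else 0
        xla_client.register_custom_call_target(
            name, fn, platform=platform, api_version=api_version)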

----------------------------------------

@@ -25,7 +25,7 @@ import jaxlib.mlir.ir as ir
 from jaxlib import xla_client

 from .hlo_helpers import custom_call
-from .gpu_common_utils import GpuLibNotLinkedError

 for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
   try:
@@ -63,18 +63,13 @@ if _hip_prng:
 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)

-# TODO(b/338022728): forward_compatibility_mode=False after 3 weeks.
 def _threefry2x32_lowering(prng, platform: str, keys, data,
                            length: int | ir.Value | None = None,
                            output_shape: ir.Value | None = None,
-                           forward_compatibility_mode: bool = True):
+                           forward_compatibility_mode: bool = False):
   """ThreeFry2x32 kernel for GPU.

   In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
   is a 1D tensor describing the shape of the two outputs.
   """
-  if forward_compatibility_mode and not prng:
-    raise GpuLibNotLinkedError()
+  del forward_compatibility_mode
   assert len(keys) == 2, keys
   assert len(data) == 2, data
   assert (ir.RankedTensorType(keys[0].type).element_type ==
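Note that the parameter is kept (now defaulting to False) so existing call sites, such as the jax lowering rule above that still passes it, do not break; the del marks it as intentionally unused. A generic sketch of that pattern:

    def apply_kernel(data, forward_compatibility_mode: bool = False):
      # Accepted for call-site compatibility, deliberately ignored.
      del forward_compatibility_mode
      return [d * 2 for d in data]

    print(apply_kernel([1, 2, 3], forward_compatibility_mode=True))  # [2, 4, 6]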
@@ -90,37 +89,18 @@ def _threefry2x32_lowering(prng, platform: str, keys, data,
   operand_layouts = [layout] * 4
   operands = [keys[0], keys[1], data[0], data[1]]

-  if forward_compatibility_mode and length is None:
-    length = _prod(dims)
-
   opaque = {}  # Use if not forward_compatibility_mode to trigger the FFI (v4).
   if isinstance(length, int):
-    if forward_compatibility_mode:
-      opaque = prng.threefry2x32_descriptor(length)
     result_shapes = None
   else:
     assert output_shape is not None
-    if forward_compatibility_mode:
-      opaque = prng.threefry2x32_descriptor(-1)
     assert (ir.RankedTensorType(length.type).element_type ==  # type: ignore[attribute-error]
             ir.IntegerType.get_signless(64)), length
     assert (ir.RankedTensorType(length.type).shape ==  # type: ignore[attribute-error]
             [1]), (length, ir.RankedTensorType(length.type).shape)  # type: ignore[attribute-error]
     # Pass the length, which will be used by the custom call target since the
     # static length in the descriptor is -1.
     operands.append(length)
     operand_layouts.append((0,))
     # We also need to pass separately the shapes of the outputs.
     result_shapes = [output_shape, output_shape]

-  custom_call_target = (
-      f"{platform}_threefry2x32"
-      if forward_compatibility_mode
-      else f"{platform}_threefry2x32_ffi"
-  )
+  custom_call_target = f"{platform}_threefry2x32_ffi"
   return custom_call(
       custom_call_target,
-      api_version=(2 if forward_compatibility_mode else 4),
+      api_version=4,
       result_types=[typ, typ],
       operands=operands,
       backend_config=opaque,
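The api_version switch is the substance of this hunk: version 2 is the legacy custom call whose backend_config is an opaque descriptor string, while version 4 is the typed FFI call whose backend_config is a dictionary of attributes. A hedged, self-contained sketch of emitting the same call with jaxlib's helper; shapes and layouts are illustrative, and the helper normally runs inside jax's lowering machinery:

    from jax._src.interpreters import mlir as jax_mlir
    from jaxlib.hlo_helpers import custom_call
    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import func as func_dialect

    with jax_mlir.make_ir_context(), ir.Location.unknown():
      module = ir.Module.create()
      u32 = ir.IntegerType.get_unsigned(32)
      typ = ir.RankedTensorType.get((8,), u32)  # illustrative shape
      with ir.InsertionPoint(module.body):
        @func_dialect.FuncOp.from_py_func(typ, typ, typ, typ)
        def threefry(k1, k2, x1, x2):
          out = custom_call(
              "cu_threefry2x32_ffi",
              api_version=4,            # typed FFI convention
              result_types=[typ, typ],
              operands=[k1, k2, x1, x2],
              backend_config={},        # v4 takes a dict of attributes, not bytes
              operand_layouts=[(0,)] * 4,
              result_layouts=[(0,)] * 2,
          )
          return tuple(out.results)
      print(module)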

----------------------------------------

@@ -66,11 +66,13 @@ from jax._src.lib import version as jaxlib_version
 config.parse_flags_with_absl()

 def _is_required_cusolver_version_satisfied(required_version):
   if cuda_versions is None:
     return False
   return cuda_versions.cusolver_get_version() >= required_version

 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_debug_key_reuse=False,
                  jax_include_full_tracebacks_in_locations=False,
@@ -116,7 +118,8 @@ class CompatTest(bctu.CompatTestBase):
       cpu_cholesky_lapack_potrf.data_2023_06_19,
       cpu_eig_lapack_geev.data_2023_06_19,
       cpu_eigh_lapack_syev.data_2023_03_17,
-      cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
+      cpu_qr_lapack_geqrf.data_2023_03_17,
+      cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
       cpu_lu_lapack_getrf.data_2023_06_14,
       cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
       cpu_schur_lapack_gees.data_2023_07_16,
@@ -144,7 +147,6 @@ class CompatTest(bctu.CompatTestBase):
       "tf.call_tf_function",  # tested in jax2tf/tests/back_compat_tf_test.py
       "tpu_custom_call",  # tested separately
       "__gpu$xla.gpu.triton",  # tested in pallas/export_back_compat_pallas_test.py
-      "cu_threefry2x32_ffi",  # TODO(b/338022728) add the actual backwards compatibility test
     })
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered,
@@ -592,7 +594,12 @@ class CompatTest(bctu.CompatTestBase):
     def func(x):
       return jax.random.uniform(x, (2, 4), dtype=np.float32)

+    # TODO(b/338022728): remove after 6 months
     data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
-    self.run_one_test(func, data)
+    self.run_one_test(func, data,
+                      expect_current_custom_calls=["cu_threefry2x32_ffi"])
+
+    data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
+    self.run_one_test(func, data)

   def test_sharding(self):
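Stepping outside the test harness for a moment: the property these two run_one_test calls exercise can also be checked with the public export API. A sketch, assuming jax.export is available (public since roughly jax 0.4.30) and serialization for the cuda platform:

    import jax
    import numpy as np
    from jax import export

    jax.config.update("jax_threefry_gpu_kernel_lowering", True)

    def func(x):
      return jax.random.uniform(x, (2, 4), dtype=np.float32)

    # Newly exported modules reference the new FFI target; the 2023 testdata
    # still carries the old cu_threefry2x32 target, which stays loadable for
    # the 6-month backward-compatibility window.
    exported = export.export(jax.jit(func), platforms=["cuda"])(jax.random.key(42))
    assert "cu_threefry2x32_ffi" in exported.mlir_module()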

----------------------------------------

@@ -36,7 +36,6 @@ from jax import lax
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.internal_test_util import test_harnesses
-from jax._src.lib import version as jaxlib_version
 from jax import random
@@ -206,14 +205,8 @@ class PrimitiveTest(jtu.JaxTestCase):
       self.export_and_compare_to_native(f, x)

   def test_random_with_threefry_gpu_kernel_lowering(self):
-    if jaxlib_version < (0, 4, 31):
-      self.skipTest("jaxlib.version < 0.4.31")
+    # TODO(b/350111820): fix the FFI registration mechanism
+    self.skipTest("b/350111820: fix the FFI registration mechanism")
     # On GPU we use a custom call for threefry2x32
     with config.threefry_gpu_kernel_lowering(True):
-      # TODO(b/338022728): clean up forward compatibility mode.
-      with config.export_ignore_forward_compatibility(True):
-        def f(x):
-          return random.gamma(random.key(42), x)
+      def f(x):
+        return random.gamma(random.key(42), x)