mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove forward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th. Also add backwards compatibility test for the new custom call. PiperOrigin-RevId: 658011228
This commit is contained in:
parent
618754d829
commit
65450d165e
File diff suppressed because one or more lines are too long
@ -46,7 +46,6 @@ from jax._src.interpreters import xla
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import gpu_prng
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.array_methods import (
|
||||
@ -901,16 +900,6 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
|
||||
if not config.threefry_gpu_kernel_lowering.value: # back to default lowering
|
||||
return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
|
||||
|
||||
# TODO(b/338022728): when we export, use the old custom call target for now.
|
||||
# Make forward_compatibility_mode False after 3 weeks.
|
||||
# TODO(b/350111820): figure out why we cannot use the new cu_threefry2x32_ffi
|
||||
# in Kokoro tests. For now, use the old cu_threefry2x32.
|
||||
# lowering_parameters = ctx.module_context.lowering_parameters
|
||||
# forward_compatibility_mode = (
|
||||
# lowering_parameters.for_export and
|
||||
# not lowering_parameters.export_ignore_forward_compatibility)
|
||||
forward_compatibility_mode = True
|
||||
|
||||
aval_out, aval_out_2 = ctx.avals_out
|
||||
assert aval_out == aval_out_2
|
||||
k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
|
||||
@ -933,17 +922,13 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
|
||||
length = int(out_len) # will be passed statically
|
||||
output_shape = None
|
||||
|
||||
if jaxlib_version >= (0, 4, 31):
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape,
|
||||
forward_compatibility_mode)
|
||||
else:
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape)
|
||||
False, # forward_compatibility_mode
|
||||
)
|
||||
|
||||
|
||||
threefry2x32_p = core.Primitive("threefry2x32")
|
||||
threefry2x32_p.multiple_results = True
|
||||
|
@ -30,18 +30,13 @@ nb::dict Registrations() {
|
||||
nb::dict dict;
|
||||
dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
|
||||
EncapsulateFfiHandler(ThreeFry2x32Ffi);
|
||||
// TODO(b/338022728): remove after 3 weeks
|
||||
// TODO(b/338022728): remove after 6 months
|
||||
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
|
||||
return dict;
|
||||
}
|
||||
|
||||
NB_MODULE(_prng, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
// TODO(b/338022728): remove after 3 weeks
|
||||
m.def("threefry2x32_descriptor", [](std::int64_t n) {
|
||||
std::string result = BuildThreeFry2x32Descriptor(n);
|
||||
return nb::bytes(result.data(), result.size());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -25,7 +25,7 @@ import jaxlib.mlir.ir as ir
|
||||
from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
@ -63,18 +63,17 @@ if _hip_prng:
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
||||
# TODO(b/338022728): forward_compatibility_mode=False after 3 weeks.
|
||||
|
||||
def _threefry2x32_lowering(prng, platform: str, keys, data,
|
||||
length: int | ir.Value | None = None,
|
||||
output_shape: ir.Value | None = None,
|
||||
forward_compatibility_mode: bool = True):
|
||||
forward_compatibility_mode: bool = False):
|
||||
"""ThreeFry2x32 kernel for GPU.
|
||||
|
||||
In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
|
||||
is a 1D tensor describing the shape of the two outputs.
|
||||
"""
|
||||
if forward_compatibility_mode and not prng:
|
||||
raise GpuLibNotLinkedError()
|
||||
del forward_compatibility_mode
|
||||
assert len(keys) == 2, keys
|
||||
assert len(data) == 2, data
|
||||
assert (ir.RankedTensorType(keys[0].type).element_type ==
|
||||
@ -90,37 +89,18 @@ def _threefry2x32_lowering(prng, platform: str, keys, data,
|
||||
operand_layouts = [layout] * 4
|
||||
operands = [keys[0], keys[1], data[0], data[1]]
|
||||
|
||||
if forward_compatibility_mode and length is None:
|
||||
length = _prod(dims)
|
||||
|
||||
opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4).
|
||||
if isinstance(length, int):
|
||||
if forward_compatibility_mode:
|
||||
opaque = prng.threefry2x32_descriptor(length)
|
||||
result_shapes = None
|
||||
else:
|
||||
assert output_shape is not None
|
||||
if forward_compatibility_mode:
|
||||
opaque = prng.threefry2x32_descriptor(-1)
|
||||
assert (ir.RankedTensorType(length.type).element_type == # type: ignore[attribute-error]
|
||||
ir.IntegerType.get_signless(64)), length
|
||||
assert (ir.RankedTensorType(length.type).shape == # type: ignore[attribute-error]
|
||||
[1]), (length, ir.RankedTensorType(length.type).shape) # type: ignore[attribute-error]
|
||||
# Pass the length, which will be used by the custom call target since the
|
||||
# static length in the descriptor is -1.
|
||||
operands.append(length)
|
||||
operand_layouts.append((0,))
|
||||
# We also need to pass separately the shapes of the outputs.
|
||||
result_shapes = [output_shape, output_shape]
|
||||
|
||||
custom_call_target = (
|
||||
f"{platform}_threefry2x32"
|
||||
if forward_compatibility_mode
|
||||
else f"{platform}_threefry2x32_ffi"
|
||||
)
|
||||
custom_call_target = f"{platform}_threefry2x32_ffi"
|
||||
return custom_call(
|
||||
custom_call_target,
|
||||
api_version=(2 if forward_compatibility_mode else 4),
|
||||
api_version=4,
|
||||
result_types=[typ, typ],
|
||||
operands=operands,
|
||||
backend_config=opaque,
|
||||
|
@ -66,11 +66,13 @@ from jax._src.lib import version as jaxlib_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _is_required_cusolver_version_satisfied(required_version):
|
||||
if cuda_versions is None:
|
||||
return False
|
||||
return cuda_versions.cusolver_get_version() >= required_version
|
||||
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key="allow",
|
||||
jax_debug_key_reuse=False,
|
||||
jax_include_full_tracebacks_in_locations=False,
|
||||
@ -116,7 +118,8 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cpu_cholesky_lapack_potrf.data_2023_06_19,
|
||||
cpu_eig_lapack_geev.data_2023_06_19,
|
||||
cpu_eigh_lapack_syev.data_2023_03_17,
|
||||
cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
|
||||
cpu_qr_lapack_geqrf.data_2023_03_17,
|
||||
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
|
||||
cpu_lu_lapack_getrf.data_2023_06_14,
|
||||
cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
|
||||
cpu_schur_lapack_gees.data_2023_07_16,
|
||||
@ -144,7 +147,6 @@ class CompatTest(bctu.CompatTestBase):
|
||||
"tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py
|
||||
"tpu_custom_call", # tested separately
|
||||
"__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py
|
||||
"cu_threefry2x32_ffi", # TODO(b/338022728) add the actual backwards compatibility test
|
||||
})
|
||||
not_covered = targets_to_cover.difference(covered_targets)
|
||||
self.assertEmpty(not_covered,
|
||||
@ -592,7 +594,12 @@ class CompatTest(bctu.CompatTestBase):
|
||||
def func(x):
|
||||
return jax.random.uniform(x, (2, 4), dtype=np.float32)
|
||||
|
||||
# TODO(b/338022728): remove after 6 months
|
||||
data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
|
||||
self.run_one_test(func, data,
|
||||
expect_current_custom_calls=["cu_threefry2x32_ffi"])
|
||||
|
||||
data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
|
||||
self.run_one_test(func, data)
|
||||
|
||||
def test_sharding(self):
|
||||
|
@ -36,7 +36,6 @@ from jax import lax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import test_harnesses
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax import random
|
||||
|
||||
|
||||
@ -206,14 +205,8 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
self.export_and_compare_to_native(f, x)
|
||||
|
||||
def test_random_with_threefry_gpu_kernel_lowering(self):
|
||||
# TODO(b/350111820): fix the FFI registration mechanism
|
||||
self.skipTest("b/350111820: fix the FFI registration mechanism")
|
||||
if jaxlib_version < (0, 4, 31):
|
||||
self.skipTest("jaxlib.version < 0.4.31")
|
||||
# On GPU we use a custom call for threefry2x32
|
||||
with config.threefry_gpu_kernel_lowering(True):
|
||||
# TODO(b/338022728): clean up forward compatibility mode.
|
||||
with config.export_ignore_forward_compatibility(True):
|
||||
def f(x):
|
||||
return random.gamma(random.key(42), x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user