Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs. For the next 3 weeks, we only use the new custom call implementation if we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
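In short, the dispatch this commit adds works as follows (a simplified sketch; the helper name `pick_threefry_target` is hypothetical, while the target names and api_version values are taken from the diffs below):

# Simplified sketch of the new dispatch (helper name is hypothetical).
def pick_threefry_target(platform: str, for_export: bool,
                         ignore_forward_compat: bool) -> str:
    # During the 3-week window, exported programs keep the old target so
    # they still run on consumers that ship an older jaxlib.
    forward_compatibility_mode = for_export and not ignore_forward_compat
    if forward_compatibility_mode:
        return f"{platform}_threefry2x32"      # legacy custom call, api_version=2
    return f"{platform}_threefry2x32_ffi"      # typed XLA FFI, api_version=4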
parent 3ebebdfb76
commit cbe524298c
@@ -174,6 +174,12 @@ and external users should:
 The compatibility guarantees do not apply if you bypass the `jax.export` APIs
 to obtain the StableHLO code.
 
+In order to ensure forward compatibility, when we change the JAX lowering rules
+to use a new custom call target, JAX will refrain from using the new target for
+3 weeks. To use the latest lowering rules, you can pass the
+`--jax_export_ignore_forward_compatibility=1` configuration flag
+or set the `JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1` environment variable.
+
 Only a subset of custom calls are guaranteed stable and have
 compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)).
 We continuously
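A usage sketch based on the flag and environment variable named above (exact spellings as in the config diff below):

# Opt out of the 3-week forward-compatibility window when exporting,
# so the newest custom call targets are used.
#
#   JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1 python export_my_model.py
#
# or, equivalently, from Python:
import jax
jax.config.update("jax_export_ignore_forward_compatibility", True)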
@@ -948,7 +948,7 @@ export_ignore_forward_compatibility = bool_state(
     default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False),
     help=(
         'Whether to ignore the forward compatibility lowering rules. '
-        'See file:///Users/necula/Source/jax/docs/build/html/export/export.html#ensuring-forward-and-backward-compatibility.'
+        'See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.'
     )
 )
@@ -922,7 +922,7 @@ def _check_lowering(lowering) -> None:
 # Their backwards compatibility is tested by back_compat_test.py.
 _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
     "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
-    "cu_threefry2x32",
+    "cu_threefry2x32", "cu_threefry2x32_ffi",
     "__gpu$xla.gpu.triton",  # Pallas call on GPU
     # cholesky on CPU
     "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
@@ -47,6 +47,7 @@ from jax._src.lax import lax as lax_internal
 from jax._src.lax import utils as lax_utils
 from jax._src.lib import gpu_prng
 from jax._src.lib import xla_client as xc
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.array_methods import (
@@ -915,6 +916,13 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
   if not config.threefry_gpu_kernel_lowering.value:  # back to default lowering
     return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)
 
+  # TODO(b/338022728): when we export, use the old custom call target for now.
+  # Make forward_compatibility_mode False after 3 weeks.
+  lowering_parameters = ctx.module_context.lowering_parameters
+  forward_compatibility_mode = (
+      lowering_parameters.for_export and
+      not lowering_parameters.export_ignore_forward_compatibility)
+
   aval_out, aval_out_2 = ctx.avals_out
   assert aval_out == aval_out_2
   k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
@@ -937,10 +945,17 @@ def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
     length = int(out_len)  # will be passed statically
     output_shape = None
 
-  return lowering_func(
-      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
-      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
-      output_shape)
+  if jaxlib_version >= (0, 4, 31):
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+        output_shape,
+        forward_compatibility_mode)
+  else:
+    return lowering_func(
+        (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+        (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
+        output_shape)
 
 threefry2x32_p = core.Primitive("threefry2x32")
 threefry2x32_p.multiple_results = True
@@ -377,7 +377,10 @@ cc_library(
         ":cuda_prng_kernels_impl",
         ":cuda_vendor",
         "//jaxlib:kernel_helpers",
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_status",
+        "@com_google_absl//absl/algorithm:container",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
@@ -392,6 +395,8 @@ cuda_library(
        ":cuda_gpu_kernel_helpers",
        ":cuda_vendor",
        "//jaxlib:kernel_helpers",
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
        "@xla//xla/service:custom_call_status",
        "@local_config_cuda//cuda:cuda_headers",
    ],
@@ -430,6 +435,8 @@ cc_library(
        ":cusolver_kernels",
        ":cusparse_kernels",
        ":triton_kernels",
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
        "@xla//xla/service:custom_call_target_registry",
    ],
    alwayslink = 1,
@@ -25,6 +25,8 @@ limitations under the License.
 #include "jaxlib/gpu/sparse_kernels.h"
 #include "jaxlib/gpu/triton_kernels.h"
 #include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_target_registry.h"
 
 namespace jax {
@@ -43,6 +45,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation",
                                          LuPivotsToPermutation, "CUDA");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
                                          "CUDA");
+XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi",
+                         "CUDA", ThreeFry2x32Ffi);
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
@@ -28,12 +28,16 @@ std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
 }
 
 nb::dict Registrations() {
   nb::dict dict;
+  dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
+      EncapsulateFunction(ThreeFry2x32Ffi);
+  // TODO(b/338022728): remove after 3 weeks
   dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
   return dict;
 }
 
 NB_MODULE(_prng, m) {
   m.def("registrations", &Registrations);
+  // TODO(b/338022728): remove after 3 weeks
   m.def("threefry2x32_descriptor", [](std::int64_t n) {
     std::string result = BuildThreeFry2x32Descriptor(n);
     return nb::bytes(result.data(), result.size());
@@ -15,16 +15,26 @@ limitations under the License.
 
 #include "jaxlib/gpu/prng_kernels.h"
 
+#include <cstdint>
+#include <functional>
 #include <string>
+#include <string_view>
 
+#include "absl/algorithm/container.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/kernel_helpers.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
+namespace ffi = xla::ffi;
+
 namespace {
 
+// TODO(b/338022728): old custom call target, remove after 6 months
 absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
                            const char* opaque, std::size_t opaque_len) {
   auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
@@ -36,6 +46,7 @@ absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
 
 }  // namespace
 
+// TODO(b/338022728): remove after 6 months
 void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
@@ -45,5 +56,32 @@ void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
   }
 }
 
+XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame) {
+  static const auto* kImpl =
+      ffi::Ffi::Bind()
+          .Ctx<ffi::PlatformStream<gpuStream_t>>()
+          .Arg<ffi::Buffer<ffi::DataType::U32>>()
+          .Arg<ffi::Buffer<ffi::DataType::U32>>()
+          .Arg<ffi::Buffer<ffi::DataType::U32>>()
+          .Arg<ffi::Buffer<ffi::DataType::U32>>()
+          .Ret<ffi::Buffer<ffi::DataType::U32>>()
+          .Ret<ffi::Buffer<ffi::DataType::U32>>()
+          .To([](gpuStream_t stream, auto keys0, auto keys1, auto data0,
+                 auto data1, auto out0, auto out1) -> ffi::Error {
+            std::int64_t n = absl::c_accumulate(out0->dimensions, 1,
+                                                std::multiplies<int64_t>());
+            LaunchThreeFry2x32KernelFfi(stream, n, keys0.data, keys1.data,
+                                        data0.data, data1.data, out0->data,
+                                        out1->data);
+            if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
+              return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
+                                std::string(status.message()));
+            }
+            return ffi::Error::Success();
+          })
+          .release();
+  return kImpl->Call(call_frame);
+}
+
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
@@ -105,6 +105,23 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
 
 }  // namespace
 
+void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
+                                 std::int64_t n,
+                                 std::uint32_t *keys0,
+                                 std::uint32_t *keys1,
+                                 std::uint32_t *data0,
+                                 std::uint32_t *data1,
+                                 std::uint32_t *out0,
+                                 std::uint32_t *out1) {
+  const int block_dim = 128;
+  const std::int64_t grid_dim =
+      std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
+  ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
+                       stream>>>(keys0, keys1, data0, data1, out0,
+                                 out1, n, nullptr);
+}
+
+// TODO(b/338022728): remove after 6 months
 void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
                               ThreeFry2x32Descriptor descriptor) {
   std::array<const std::uint32_t*, 2> keys;
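The launch above clamps the grid at 1024 blocks, so for large n each thread must cover several elements (the striding happens inside the kernel body, which this hunk does not show). A quick illustrative check of the geometry:

# Illustrative check of the launch geometry in LaunchThreeFry2x32KernelFfi.
block_dim = 128
for n in (1, 129, 1 << 20):
    grid_dim = min(1024, (n + block_dim - 1) // block_dim)
    print(n, grid_dim)  # 1 -> 1, 129 -> 2, 1048576 -> 1024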
@@ -21,21 +21,33 @@ limitations under the License.
 #include <string>
 
 #include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/c_api.h"
 #include "xla/service/custom_call_status.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
+// TODO(b/338022728): remove after 6 months
 struct ThreeFry2x32Descriptor {
   std::int64_t n;  // If -1 then the length is passed as a 5th operand
 };
 
+// TODO(b/338022728): remove after 6 months
 void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
                               ThreeFry2x32Descriptor descriptor);
 
+// TODO(b/338022728): remove after 6 months
 void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len, XlaCustomCallStatus* status);
 
+XLA_FFI_Error* ThreeFry2x32Ffi(XLA_FFI_CallFrame* call_frame);
+
+void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
+                                 std::int64_t n,
+                                 std::uint32_t *keys0, std::uint32_t *keys1,
+                                 std::uint32_t *data0, std::uint32_t *data1,
+                                 std::uint32_t *out0, std::uint32_t *out1);
+
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
@@ -39,7 +39,10 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
 
 if _cuda_prng:
   for _name, _value in _cuda_prng.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
+    # TODO(b/338022728): remove after 6 months, always api_version=1
+    api_version = 1 if "_ffi" in _name else 0
+    xla_client.register_custom_call_target(_name, _value, platform="CUDA",
+                                           api_version=api_version)
 
 for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
   try:
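In the registration above, `api_version` selects the handler ABI that `xla_client.register_custom_call_target` expects: 0 for the legacy untyped custom call, 1 for a typed XLA FFI handler, hence the `_ffi` test on the name. A sketch of what the loop effectively registers (capsule values elided):

# xla_client.register_custom_call_target(
#     "cu_threefry2x32", <legacy capsule>, platform="CUDA", api_version=0)
# xla_client.register_custom_call_target(
#     "cu_threefry2x32_ffi", <FFI capsule>, platform="CUDA", api_version=1)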
@@ -53,19 +56,24 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
 
 if _hip_prng:
   for _name, _value in _hip_prng.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
+    # TODO(b/338022728): remove after 6 months, always api_version=1
+    api_version = 1 if "_ffi" in _name else 0
+    xla_client.register_custom_call_target(_name, _value, platform="ROCM",
+                                           api_version=api_version)
 
 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)
 
-def _threefry2x32_lowering(prng, platform, keys, data,
+# TODO(b/338022728): forward_compatibility_mode=False after 3 weeks.
+def _threefry2x32_lowering(prng, platform: str, keys, data,
                            length: int | ir.Value | None = None,
-                           output_shape: ir.Value | None = None):
+                           output_shape: ir.Value | None = None,
+                           forward_compatibility_mode: bool = True):
   """ThreeFry2x32 kernel for GPU.
 
   In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
   is a 1D tensor describing the shape of the two outputs.
   """
-  if not prng:
+  if forward_compatibility_mode and not prng:
     raise GpuLibNotLinkedError()
   assert len(keys) == 2, keys
   assert len(data) == 2, data
@@ -82,28 +90,37 @@ def _threefry2x32_lowering(prng, platform, keys, data,
   operand_layouts = [layout] * 4
   operands = [keys[0], keys[1], data[0], data[1]]
 
-  if length is None:
+  if forward_compatibility_mode and length is None:
     length = _prod(dims)
 
+  opaque = {}  # Use if not forward_compatibility_mode to trigger the FFI (v4).
   if isinstance(length, int):
-    opaque = prng.threefry2x32_descriptor(length)
+    if forward_compatibility_mode:
+      opaque = prng.threefry2x32_descriptor(length)
     result_shapes = None
   else:
     assert output_shape is not None
-    opaque = prng.threefry2x32_descriptor(-1)
-    assert (ir.RankedTensorType(length.type).element_type ==
-            ir.IntegerType.get_signless(64)), length
-    assert (ir.RankedTensorType(length.type).shape ==
-            [1]), (length, ir.RankedTensorType(length.type).shape)
-    # Pass the length, which will be used by the custom call target since the
-    # static length in the descriptor is -1.
-    operands.append(length)
-    operand_layouts.append((0,))
+    if forward_compatibility_mode:
+      opaque = prng.threefry2x32_descriptor(-1)
+    assert (ir.RankedTensorType(length.type).element_type ==  # type: ignore[attribute-error]
+            ir.IntegerType.get_signless(64)), length
+    assert (ir.RankedTensorType(length.type).shape ==  # type: ignore[attribute-error]
+            [1]), (length, ir.RankedTensorType(length.type).shape)  # type: ignore[attribute-error]
+    # Pass the length, which will be used by the custom call target since the
+    # static length in the descriptor is -1.
+    operands.append(length)
+    operand_layouts.append((0,))
     # We also need to pass separately the shapes of the outputs.
     result_shapes = [output_shape, output_shape]
 
+  custom_call_target = (
+      f"{platform}_threefry2x32"
+      if forward_compatibility_mode
+      else f"{platform}_threefry2x32_ffi"
+  )
   return custom_call(
-      f"{platform}_threefry2x32",
+      custom_call_target,
+      api_version=(2 if forward_compatibility_mode else 4),
       result_types=[typ, typ],
       operands=operands,
       backend_config=opaque,
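The `api_version` switch above maps to XLA's `CustomCallApiVersion`: 2 is the status-returning custom call, which receives the packed descriptor as an opaque string, while 4 is the typed FFI, which takes a dictionary `backend_config` (empty here) and lets the handler recover the length from the output buffer's dimensions. A minimal sketch of the two resulting call shapes (argument values illustrative):

# Legacy path (forward_compatibility_mode=True):
#   custom_call("cu_threefry2x32", api_version=2,
#               backend_config=<packed ThreeFry2x32Descriptor>, ...)
# Typed FFI path (forward_compatibility_mode=False):
#   custom_call("cu_threefry2x32_ffi", api_version=4,
#               backend_config={}, ...)  # length from out0's dimensions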
@@ -1473,6 +1473,7 @@ jax_test(
     srcs = ["export_harnesses_multi_platform_test.py"],
     disable_configs = [
         "gpu_a100",  # TODO(b/269593297): matmul precision issues
+        "gpu_h100",  # Scarce resources.
     ],
     shard_count = {
         "cpu": 40,
@@ -19,6 +19,7 @@ update these tests.
 import dataclasses
 from functools import partial
 import itertools
+import logging
 import math
 
 from absl.testing import absltest, parameterized
@@ -62,6 +63,7 @@ from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
+from jax._src.lib import xla_extension_version
 
 config.parse_flags_with_absl()
@@ -139,6 +141,7 @@ class CompatTest(bctu.CompatTestBase):
       "tf.call_tf_function",  # tested in jax2tf/tests/back_compat_tf_test.py
       "tpu_custom_call",  # tested separately
       "__gpu$xla.gpu.triton",  # tested in pallas/export_back_compat_pallas_test.py
+      "cu_threefry2x32_ffi",  # TODO(b/338022728) add the actual backwards compatibility test
     })
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered,
@@ -577,6 +580,8 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data)
 
   def test_cuda_threefry2x32(self):
+    logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
+                 xla_extension_version)
     def func(x):
       return jax.random.uniform(x, (2, 4), dtype=np.float32)
 
@@ -33,17 +33,23 @@ import numpy as np
 import jax
 from jax import export
 from jax import lax
+from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.internal_test_util import test_harnesses
+from jax._src.lib import version as jaxlib_version
 from jax import random
 
 
 def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
-  return re.compile("(" + "|".join(parts) + ")")
+  if not parts:
+    return re.compile("matches_no_test")
+  else:
+    return re.compile("(" + "|".join(parts) + ")")
 
 # TODO(necula): Failures to be investigated (on GPU).
 _known_failures_gpu = make_disjunction_regexp(
-    # Failures due to failure to export custom call targets for GPU, these
-    # targets do not have backwards compatibility tests.
+    # Failures on GPU due to failure to export custom call targets, these
+    # involve GPU custom call targets without backwards compatibility tests.
     "custom_linear_solve_",
     "lu_",
     "svd_",
@@ -54,9 +60,9 @@ _known_failures_gpu = make_disjunction_regexp(
 # CUDA lowering.
 _skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp(
     "svd_", "lu_", "eigh_", "qr_", "custom_linear_", "tridiagonal_solve_",
+    "random_",
 )
 
 
 class PrimitiveTest(jtu.JaxTestCase):
 
   @classmethod
@@ -84,7 +90,7 @@ class PrimitiveTest(jtu.JaxTestCase):
   @test_harnesses.parameterized(
       test_harnesses.all_harnesses,
       include_jax_unimpl=False,
-      #one_containing="",
+      # one_containing="",
   )
   @jtu.ignore_warning(
       category=UserWarning,
|
||||
x = (x % 2).astype(np.bool_)
|
||||
self.export_and_compare_to_native(f, x)
|
||||
|
||||
def test_random_with_threefry_gpu_kernel_lowering(self):
|
||||
if jaxlib_version < (0, 4, 31):
|
||||
self.skipTest("jaxlib.version < 0.4.31")
|
||||
# On GPU we use a custom call for thrteefry2x32
|
||||
with config.threefry_gpu_kernel_lowering(True):
|
||||
# TODO(b/338022728): clean up forward compatibility mode.
|
||||
with config.export_ignore_forward_compatibility(True):
|
||||
def f(x):
|
||||
return random.gamma(random.key(42), x)
|
||||
|
||||
shape = (4, 5)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
self.export_and_compare_to_native(f, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|