Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time. When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
parent 67038321f8
commit 534d812b57
WORKSPACE

@@ -23,10 +23,10 @@ http_archive(
 # and update the sha256 with the result.
 http_archive(
     name = "org_tensorflow",
-    sha256 = "d83221d413fd510ac8bc68ae158fcc17c2300c7c8c3bd439e30300438c2e3ce0",
-    strip_prefix = "tensorflow-d85c68d5bdcb3f72abfe22a73a638c11def69a7e",
+    sha256 = "412ef0824d5dcfe6e139e1fa25f72569e699b3ec06d374c0e19ba0bf60c32952",
+    strip_prefix = "tensorflow-883b5becaced22f7dd9e3c23d9d259f55e087cb5",
     urls = [
-        "https://github.com/tensorflow/tensorflow/archive/d85c68d5bdcb3f72abfe22a73a638c11def69a7e.tar.gz",
+        "https://github.com/tensorflow/tensorflow/archive/883b5becaced22f7dd9e3c23d9d259f55e087cb5.tar.gz",
     ],
 )
build/BUILD

@@ -14,10 +14,7 @@

 # JAX is Autograd and XLA

-load(
-    "@org_tensorflow//tensorflow/core/platform:default/cuda_build_defs.bzl",
-    "if_cuda_is_configured",
-)
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

 licenses(["notice"])  # Apache 2

@@ -32,9 +29,10 @@ sh_binary(
         "//jaxlib",
         "//jaxlib:lapack.so",
         "//jaxlib:pytree",
-    ] + if_cuda_is_configured([
+    ] + if_cuda([
         "//jaxlib:cublas_kernels",
         "//jaxlib:cusolver_kernels",
+        "//jaxlib:cuda_prng_kernels",
     ]),
     deps = ["@bazel_tools//tools/bash/runfiles"],
 )
.bazelrc

@@ -177,6 +177,9 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

 # Sets the default Apple platform to macOS.
 build --apple_platform_type=macos

+# Make Bazel print out all options from rc files.
+build --announce_rc
+
 # Disable enabled-by-default TensorFlow features that we don't care about.
 build --define=no_aws_support=true
 build --define=no_gcp_support=true
@ -57,9 +57,11 @@ cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
|
||||
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
|
||||
cp -f "$(rlocation __main__/jaxlib/cublas_kernels.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/cuda_prng_kernels.so)" "${TARGET}/jaxlib"
|
||||
fi
|
||||
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation __main__/jaxlib/cuda_prng.py)" "${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
|
||||
"${TARGET}/jaxlib"
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so)" \
|
||||
|
jax/interpreters/batching.py

@@ -182,13 +182,15 @@ def broadcast_batcher(prim, args, dims, **params):
   if len(shapes) == 1:
     # if there's only agreeing batch dims and scalars, just call the primitive
     d = next(d for d in dims if d is not not_mapped)
-    return prim.bind(*args, **params), d
+    out = prim.bind(*args, **params)
+    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
   else:
     size, = {shape[d] for shape, d in shapes if d is not not_mapped}
     args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
     ndim = max(onp.ndim(x) for x in args)  # special-case scalar broadcasting
     args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
-    return prim.bind(*args, **params), 0
+    out = prim.bind(*args, **params)
+    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)

 def _handle_scalar_broadcasting(nd, x, d):
   if d is not_mapped or nd == onp.ndim(x):
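This new `multiple_results` handling in `broadcast_batcher` is what makes the tuple-returning `threefry2x32_p` primitive (registered via `batching.defbroadcasting` in jax/random.py below) work under `vmap`. A minimal sketch of the behavior it enables, assuming a post-commit jax and jaxlib (the array values here are arbitrary):

import numpy as onp
import jax
import jax.numpy as np
from jax import random

# Batch of three independent count arrays, hashed under one fixed (unmapped)
# key. Each output of the tuple-returning primitive receives the same batch
# dimension (0), which is exactly what the multiple_results branch returns.
key = (onp.uint32(1), onp.uint32(42))
counts = np.arange(12, dtype=np.uint32).reshape(3, 4)
out = jax.vmap(lambda c: random.threefry_2x32(key, c))(counts)
assert out.shape == (3, 4)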
jax/lax/__init__.py

@@ -20,7 +20,8 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
                   _select_and_gather_add, _float, _complex,
                   _input_dtype, _const, _eq_meet, _safe_mul,
                   _broadcasting_select, _check_user_dtype_supported,
-                  _one, _const, _upcast_fp16_for_computation)
+                  _one, _const, _upcast_fp16_for_computation,
+                  _broadcasting_shape_rule)
 from .lax_control_flow import *
 from .lax_fft import *
 from .lax_parallel import *
jax/lib/__init__.py

@@ -53,3 +53,7 @@ from jaxlib import lapack

 from jaxlib import pytree
 from jaxlib import cusolver
+try:
+  from jaxlib import cuda_prng
+except ImportError:
+  cuda_prng = None
jax/random.py

@@ -25,6 +25,7 @@ from __future__ import division
 from __future__ import print_function

+from functools import partial
 import itertools

 import numpy as onp

@@ -35,9 +36,13 @@ from . import dtypes
 from .api import custom_transforms, defjvp, jit, vmap
 from .numpy.lax_numpy import _constant_like, asarray, stack
 from jax.lib import xla_bridge
+from jax.lib import cuda_prng
+from jax import core
+from jax import abstract_arrays
 from jax.scipy.special import logit
 from jax.scipy.linalg import cholesky
+from jax.interpreters import batching
 from jax.interpreters import xla


 def PRNGKey(seed):

@@ -92,9 +97,18 @@ def _bit_stats(bits):

 ### hash function and split

+def _threefry2x32_abstract_eval(*args):
+  if any(a.dtype != np.uint32 for a in args):
+    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
+                    .format(args))
+  if all(isinstance(arg, abstract_arrays.ShapedArray) for arg in args):
+    shape = lax._broadcasting_shape_rule(*args)
+    aval = abstract_arrays.ShapedArray(shape, np.dtype(np.uint32))
+  else:
+    aval = abstract_arrays.UnshapedArray(np.dtype(np.uint32))
+  return (aval,) * 2
+
-@jit
-def threefry_2x32(keypair, count):
+def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   """Apply the Threefry 2x32 hash.

   Args:
@@ -104,13 +118,8 @@ def threefry_2x32(keypair, count):
   Returns:
     An array of dtype uint32 with the same shape as `count`.
   """
-  # Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
-  key1, key2 = keypair
-  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
-    msg = "threefry_2x32 requires uint32 arguments, got {}"
-    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
-
-  rotate_left = _make_rotate_left(lax.dtype(count))
+  x = [x1, x2]
+  rotate_left = _make_rotate_left(onp.uint32)

   def apply_round(v, rot):
     v = v[:]
@@ -119,24 +128,11 @@ def threefry_2x32(keypair, count):
     v[1] = v[0] ^ v[1]
     return v

-  odd_size = count.size % 2
-  if odd_size:
-    x = list(np.split(np.concatenate([count.ravel(), onp.uint32([0])]), 2))
-  else:
-    x = list(np.split(count.ravel(), 2))
-
   rotations = [onp.array([13, 15, 26, 6], dtype=onp.uint32),
                onp.array([17, 29, 16, 24], dtype=onp.uint32)]
   ks = [key1, key2, key1 ^ key2 ^ onp.uint32(0x1BD11BDA)]

-  # TODO(mattjj): see https://github.com/google/jax/issues/1267, as a hopefully
-  # temporary workaround for the facts that (1) XLA:CPU compile time is too slow
-  # with unrolled loops and (2) XLA:GPU execution time is too slow with rolled
-  # loops, we switch on whether the default backend is CPU or GPU. If this kind
-  # of switch ends up sticking around, we should take into account #1211 and put
-  # the switch in the translation rule rather than here in the traceable.
-  use_rolled_loops = xla_bridge.get_backend().platform == "cpu"
-
   x[0] = x[0] + ks[0]
   x[1] = x[1] + ks[1]

@@ -176,6 +172,56 @@ def threefry_2x32(keypair, count):
   x[0] = x[0] + ks[2]
   x[1] = x[1] + ks[0] + onp.uint32(5)

   return tuple(x)

+
+def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
+  shape = lax.broadcast_shapes(
+      c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
+      c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
+  rank = len(shape)
+  def _broadcast(x):
+    ndims = c.GetShape(x).rank()
+    return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))
+  return cuda_prng.threefry2x32(
+      c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
+
+threefry2x32_p = core.Primitive("threefry2x32")
+threefry2x32_p.multiple_results = True
+threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
+threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
+batching.defbroadcasting(threefry2x32_p)
+xla.translations[threefry2x32_p] = xla.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=False), instantiate=True)
+xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
+    partial(_threefry2x32_lowering, use_rolled_loops=True), instantiate=True)
+if cuda_prng:
+  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
+      _threefry2x32_gpu_translation_rule
+
+
+@jit
+def threefry_2x32(keypair, count):
+  """Apply the Threefry 2x32 hash.
+
+  Args:
+    keypair: a pair of 32bit unsigned integers used for the key.
+    count: an array of dtype uint32 used for the counts.
+
+  Returns:
+    An array of dtype uint32 with the same shape as `count`.
+  """
+  key1, key2 = keypair
+  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
+    msg = "threefry_2x32 requires uint32 arguments, got {}"
+    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
+
+  odd_size = count.size % 2
+  if odd_size:
+    x = list(np.split(np.concatenate([count.ravel(), onp.uint32([0])]), 2))
+  else:
+    x = list(np.split(count.ravel(), 2))
+
+  x = threefry2x32_p.bind(key1, key2, x[0], x[1])
+  out = np.concatenate(x)
+  assert out.dtype == onp.uint32
+  return lax.reshape(out[:-1] if odd_size else out, count.shape)
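With the primitive in place, `threefry_2x32` must produce identical bits whichever lowering runs: the rolled XLA loop on CPU, the unrolled one elsewhere, or the CUDA kernel on GPU. A known-answer sketch using the same Threefry test vector as `testThreefry2x32Large` in tests/random_test.py below (post-commit jax assumed):

import numpy as onp
import jax.numpy as np
from jax import random

# Hash the standard Threefry2x32 test vector; the expected outputs are the
# ones asserted by testThreefry2x32Large.
result = random.threefry_2x32(
    (onp.uint32(0x13198a2e), onp.uint32(0x03707344)),
    np.array([0x243f6a88, 0x85a308d3], dtype=np.uint32))
assert [hex(int(v)) for v in result] == ['0xc4923a9c', '0x483df7a0']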
58  jaxlib/BUILD

@@ -16,11 +16,24 @@

 load("@org_tensorflow//tensorflow/core/platform:default/build_config.bzl", "pyx_library")
 load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")

 licenses(["notice"])

 package(default_visibility = ["//visibility:public"])

+cc_library(
+    name = "kernel_helpers",
+    hdrs = ["kernel_helpers.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    deps = [
+        "@com_google_absl//absl/base",
+        "@pybind11",
+    ],
+)
+
 pyx_library(
     name = "lapack",
     srcs = ["lapack.pyx"],
@@ -30,11 +43,21 @@ pyx_library(
 py_library(
     name = "jaxlib",
     srcs = [
+        "cuda_prng.py",
         "cusolver.py",
         "version.py",
     ],
 )

+py_library(
+    name = "gpu_support",
+    deps = [
+        ":cublas_kernels",
+        ":cuda_prng_kernels",
+        ":cusolver_kernels",
+    ],
+)
+
 pybind_extension(
     name = "pytree",
     srcs = ["pytree.cc"],
@@ -65,6 +88,7 @@ pybind_extension(
     features = ["-use_header_modules"],
     module_name = "cublas_kernels",
     deps = [
+        ":kernel_helpers",
        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -91,6 +115,7 @@ pybind_extension(
     features = ["-use_header_modules"],
     module_name = "cusolver_kernels",
     deps = [
+        ":kernel_helpers",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -102,6 +127,37 @@ pybind_extension(
         "@local_config_cuda//cuda:cuda_headers",
         "@local_config_cuda//cuda:cudart",
         "@local_config_cuda//cuda:cusolver",
     ],
 )

+cuda_library(
+    name = "cuda_prng_kernels_lib",
+    srcs = ["cuda_prng_kernels.cu.cc"],
+    hdrs = ["cuda_prng_kernels.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    deps = [
+        ":kernel_helpers",
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
+pybind_extension(
+    name = "cuda_prng_kernels",
+    srcs = ["cuda_prng_kernels.cc"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    module_name = "cuda_prng_kernels",
+    deps = [
+        ":cuda_prng_kernels_lib",
+        ":kernel_helpers",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cudart",
+        "@pybind11",
+    ],
+)
jaxlib/cublas.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "include/pybind11/numpy.h"
 #include "include/pybind11/pybind11.h"
 #include "include/pybind11/stl.h"
+#include "jaxlib/kernel_helpers.h"

 namespace jax {
 namespace {
@@ -186,27 +187,6 @@ int SizeOfType(Type type) {
   }
 }

-// Descriptor objects are opaque host-side objects used to pass data from JAX
-// to the custom kernel launched by XLA. Currently simply treat host-side
-// structures as byte-strings; this is not portable across architectures. If
-// portability is needed, we could switch to using a representation such as
-// protocol buffers or flatbuffers.
-
-// Packs a descriptor object into a py::bytes structure.
-template <typename T>
-py::bytes PackDescriptor(const T& descriptor) {
-  return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
-}
-
-// Unpacks a descriptor object from a byte string.
-template <typename T>
-const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
-  if (opaque_len != sizeof(T)) {
-    throw std::runtime_error("Invalid size for linalg operation descriptor.");
-  }
-  return absl::bit_cast<const T*>(opaque);
-}
-
 // Builds an array of pointers to each array in a batch, in device memory.
 template <typename T>
 cudaError_t MakeBatchPointers(T* buffer, T** dev_ptrs, int batch,
@@ -390,11 +370,6 @@ void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
   }
 }

-template <typename T>
-py::capsule EncapsulateFunction(T* fn) {
-  return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
-}
-
 py::dict Registrations() {
   py::dict dict;
   dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
@@ -409,4 +384,4 @@ PYBIND11_MODULE(cublas_kernels, m) {
 }

 }  // namespace
 }  // namespace jax
58  jaxlib/cuda_prng.py (new file)

@@ -0,0 +1,58 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import operator
+
+import numpy as np
+from six.moves import reduce
+
+from jaxlib import xla_client
+
+try:
+  from jaxlib import cuda_prng_kernels
+  for _name, _value in cuda_prng_kernels.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform="gpu")
+except ImportError:
+  pass
+
+_prod = lambda xs: reduce(operator.mul, xs, 1)
+
+
+def threefry2x32(c, keys, data):
+  """ThreeFry2x32 kernel for GPU."""
+  assert len(keys) == 2, keys
+  assert len(data) == 2, data
+  dims = c.GetShape(keys[0]).dimensions()
+  dtype = np.dtype(np.uint32)
+  for x in itertools.chain(keys, data):
+    x_shape = c.GetShape(x)
+    assert x_shape.element_type() == dtype
+    assert dims == x_shape.dimensions(), (dims, x_shape)
+  ndims = len(dims)
+
+  opaque = cuda_prng_kernels.cuda_threefry2x32_descriptor(_prod(dims))
+  layout = tuple(range(ndims - 1, -1, -1))
+  shape = xla_client.Shape.array_shape(dtype, dims, layout)
+  return c.CustomCall(
+      b"cuda_threefry2x32",
+      operands=(keys[0], keys[1], data[0], data[1]),
+      shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
+      operand_shapes_with_layout=(shape,) * 4,
+      opaque=opaque)
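The `opaque` argument above is nothing more than the element count packed as raw host bytes: `cuda_threefry2x32_descriptor(n)` builds a `ThreeFry2x32Descriptor{n}` and `PackDescriptor` (kernel_helpers.h below) copies its bytes. A rough Python equivalent, purely illustrative; the `struct` format string and the reliance on native byte order are assumptions of this sketch, not APIs from the commit:

import struct

def _threefry2x32_descriptor_sketch(n):
  # ThreeFry2x32Descriptor holds a single std::int64_t n; PackDescriptor
  # copies its raw bytes. "=q" packs one int64 in native byte order, which
  # matches the "not portable across architectures" caveat in the header.
  return struct.pack("=q", n)

assert len(_threefry2x32_descriptor_sketch(12)) == 8  # sizeof(std::int64_t)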
36  jaxlib/cuda_prng_kernels.cc (new file)

@@ -0,0 +1,36 @@
+/* Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/cuda_prng_kernels.h"
+
+#include "jaxlib/kernel_helpers.h"
+#include "include/pybind11/pybind11.h"
+
+namespace jax {
+namespace {
+
+pybind11::dict Registrations() {
+  pybind11::dict dict;
+  dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32);
+  return dict;
+}
+
+PYBIND11_MODULE(cuda_prng_kernels, m) {
+  m.def("registrations", &Registrations);
+  m.def("cuda_threefry2x32_descriptor", &BuildCudaThreeFry2x32Descriptor);
+}
+
+}  // namespace
+}  // namespace jax
128  jaxlib/cuda_prng_kernels.cu.cc (new file)

@@ -0,0 +1,128 @@
+/* Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstddef>
+
+#include "jaxlib/cuda_prng_kernels.h"
+#include "jaxlib/kernel_helpers.h"
+
+namespace jax {
+namespace {
+
+__global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
+                                   const std::uint32_t* key1,
+                                   const std::uint32_t* data0,
+                                   const std::uint32_t* data1,
+                                   std::uint32_t* out0, std::uint32_t* out1,
+                                   std::int64_t n) {
+  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
+       idx += blockDim.x * gridDim.x) {
+    // Rotation distances specified by the Threefry2x32 algorithm.
+    std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
+    std::uint32_t x[2];
+    std::uint32_t ks[3];
+
+    // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
+    ks[2] = 0x1BD11BDA;
+
+    ks[0] = key0[idx];
+    x[0] = data0[idx];
+    ks[2] = ks[2] ^ key0[idx];
+
+    ks[1] = key1[idx];
+    x[1] = data1[idx];
+    ks[2] = ks[2] ^ key1[idx];
+
+    auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
+      return (v << distance) | (v >> (32 - distance));
+    };
+
+    // Performs a single round of the Threefry2x32 algorithm, with a rotation
+    // amount 'rotation'.
+    auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
+      v[0] += v[1];
+      v[1] = rotate_left(v[1], rotation);
+      v[1] ^= v[0];
+    };
+
+    // There are no known statistical flaws with 13 rounds of Threefry2x32.
+    // We are conservative and use 20 rounds.
+    x[0] = x[0] + ks[0];
+    x[1] = x[1] + ks[1];
+    for (int i = 0; i < 4; ++i) {
+      round(x, rotations[i]);
+    }
+
+    x[0] = x[0] + ks[1];
+    x[1] = x[1] + ks[2] + 1u;
+    for (int i = 4; i < 8; ++i) {
+      round(x, rotations[i]);
+    }
+
+    x[0] = x[0] + ks[2];
+    x[1] = x[1] + ks[0] + 2u;
+    for (int i = 0; i < 4; ++i) {
+      round(x, rotations[i]);
+    }
+
+    x[0] = x[0] + ks[0];
+    x[1] = x[1] + ks[1] + 3u;
+    for (int i = 4; i < 8; ++i) {
+      round(x, rotations[i]);
+    }
+
+    x[0] = x[0] + ks[1];
+    x[1] = x[1] + ks[2] + 4u;
+    for (int i = 0; i < 4; ++i) {
+      round(x, rotations[i]);
+    }
+
+    out0[idx] = x[0] + ks[2];
+    out1[idx] = x[1] + ks[0] + 5u;
+  }
+}
+
+}  // namespace
+
+struct ThreeFry2x32Descriptor {
+  std::int64_t n;
+};
+
+pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
+  return PackDescriptor(ThreeFry2x32Descriptor{n});
+}
+
+void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
+                      std::size_t opaque_len) {
+  std::array<const std::uint32_t*, 2> keys;
+  keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
+  keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
+  std::array<const std::uint32_t*, 2> data;
+  data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
+  data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
+  std::array<std::uint32_t*, 2> out;
+  out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
+  out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
+  const auto& descriptor =
+      *UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
+  const int block_dim = 128;
+  const std::int64_t grid_dim =
+      std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
+  ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
+                       stream>>>(keys[0], keys[1], data[0], data[1], out[0],
+                                 out[1], descriptor.n);
+}
+
+}  // namespace jax
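For readers cross-checking the kernel, here is a pure-Python transcription of the per-element hash above (an editorial sketch, not part of the commit). Arithmetic is masked to 32 bits to mirror uint32 wraparound on the GPU; the assert uses the same known-answer vector as testThreefry2x32Large in tests/random_test.py below.

MASK = 0xFFFFFFFF  # mirror uint32 wraparound

def _rotl32(v, d):
  return ((v << d) | (v >> (32 - d))) & MASK

def threefry2x32_reference(key0, key1, d0, d1):
  rotations = [13, 15, 26, 6, 17, 29, 16, 24]
  ks = [key0, key1, 0x1BD11BDA ^ key0 ^ key1]
  x = [(d0 + ks[0]) & MASK, (d1 + ks[1]) & MASK]
  def rounds(rs):
    for r in rs:
      x[0] = (x[0] + x[1]) & MASK     # v[0] += v[1]
      x[1] = _rotl32(x[1], r) ^ x[0]  # v[1] = rotate_left(v[1], r) ^ v[0]
  # 20 rounds total: five groups of four, with a key injection after each.
  rounds(rotations[:4]); x[0] = (x[0] + ks[1]) & MASK; x[1] = (x[1] + ks[2] + 1) & MASK
  rounds(rotations[4:]); x[0] = (x[0] + ks[2]) & MASK; x[1] = (x[1] + ks[0] + 2) & MASK
  rounds(rotations[:4]); x[0] = (x[0] + ks[0]) & MASK; x[1] = (x[1] + ks[1] + 3) & MASK
  rounds(rotations[4:]); x[0] = (x[0] + ks[1]) & MASK; x[1] = (x[1] + ks[2] + 4) & MASK
  rounds(rotations[:4])
  return (x[0] + ks[2]) & MASK, (x[1] + ks[0] + 5) & MASK

assert threefry2x32_reference(0x13198a2e, 0x03707344, 0x243f6a88, 0x85a308d3) \
    == (0xc4923a9c, 0x483df7a0)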
33  jaxlib/cuda_prng_kernels.h (new file)

@@ -0,0 +1,33 @@
+/* Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_PRNG_KERNELS_H_
+#define JAXLIB_PRNG_KERNELS_H_
+
+#include <cstddef>
+
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "include/pybind11/pybind11.h"
+
+namespace jax {
+
+pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n);
+
+void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
+                      std::size_t opaque_len);
+
+}  // namespace jax
+
+#endif  // JAXLIB_PRNG_KERNELS_H_
jaxlib/cusolver.cc

@@ -29,6 +29,7 @@ limitations under the License.
 #include "include/pybind11/numpy.h"
 #include "include/pybind11/pybind11.h"
 #include "include/pybind11/stl.h"
+#include "jaxlib/kernel_helpers.h"

 namespace jax {
 namespace {
@@ -192,27 +193,6 @@ int SizeOfType(Type type) {
   }
 }

-// Descriptor objects are opaque host-side objects used to pass data from JAX
-// to the custom kernel launched by XLA. Currently simply treat host-side
-// structures as byte-strings; this is not portable across architectures. If
-// portability is needed, we could switch to using a representation such as
-// protocol buffers or flatbuffers.
-
-// Packs a descriptor object into a py::bytes structure.
-template <typename T>
-py::bytes PackDescriptor(const T& descriptor) {
-  return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
-}
-
-// Unpacks a descriptor object from a byte string.
-template <typename T>
-const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
-  if (opaque_len != sizeof(T)) {
-    throw std::runtime_error("Invalid size for linalg operation descriptor.");
-  }
-  return absl::bit_cast<const T*>(opaque);
-}
-
 // getrf: LU decomposition

 struct GetrfDescriptor {
@@ -1150,11 +1130,6 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
   }
 }

-template <typename T>
-py::capsule EncapsulateFunction(T* fn) {
-  return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
-}
-
 py::dict Registrations() {
   py::dict dict;
   dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
55  jaxlib/kernel_helpers.h (new file)

@@ -0,0 +1,55 @@
+/* Copyright 2019 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_KERNEL_HELPERS_H_
+#define JAXLIB_KERNEL_HELPERS_H_
+
+#include <cstddef>
+#include <stdexcept>
+
+#include "absl/base/casts.h"
+#include "include/pybind11/pybind11.h"
+
+namespace jax {
+
+// Descriptor objects are opaque host-side objects used to pass data from JAX
+// to the custom kernel launched by XLA. Currently simply treat host-side
+// structures as byte-strings; this is not portable across architectures. If
+// portability is needed, we could switch to using a representation such as
+// protocol buffers or flatbuffers.
+
+// Packs a descriptor object into a pybind11::bytes structure.
+template <typename T>
+pybind11::bytes PackDescriptor(const T& descriptor) {
+  return pybind11::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
+}
+
+// Unpacks a descriptor object from a byte string.
+template <typename T>
+const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
+  if (opaque_len != sizeof(T)) {
+    throw std::runtime_error("Invalid size for linalg operation descriptor.");
+  }
+  return absl::bit_cast<const T*>(opaque);
+}
+
+template <typename T>
+pybind11::capsule EncapsulateFunction(T* fn) {
+  return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
+}
+
+}  // namespace jax
+
+#endif  // JAXLIB_KERNEL_HELPERS_H_
tests/random_test.py

@@ -95,6 +95,17 @@ class LaxRandomTest(jtu.JaxTestCase):
                         onp.uint32([0x243f6a88, 0x85a308d3]))
     self.assertEqual(expected, result_to_hex(result))

+  def testThreefry2x32Large(self):
+    n = 10000000
+    result = random.threefry_2x32(
+        (onp.uint32(0x13198a2e), onp.uint32(0x03707344)),
+        np.concatenate([
+            np.full((n,), 0x243f6a88, np.uint32),
+            np.full((n,), 0x85a308d3, np.uint32)
+        ]))
+    onp.testing.assert_equal(result[:n], onp.full((n,), 0xc4923a9c, dtype=onp.uint32))
+    onp.testing.assert_equal(result[n:], onp.full((n,), 0x483df7a0, dtype=onp.uint32))
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
       for dtype in [onp.float32, onp.float64]))