Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)

In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes GPU compilation times to blow up when a program compiles potentially many copies of the PRNG kernel. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
Peter Hawkins 2019-11-24 13:06:23 -05:00 committed by GitHub
parent 67038321f8
commit 534d812b57
17 changed files with 470 additions and 87 deletions
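
To make the compile-time issue concrete, here is a rough sketch (not part of the commit) of the kind of program that was affected: under jit, every random.split/random.normal call in the traced Python loop lowers to its own copy of the ThreeFry2x32 hash, and before this change XLA:GPU inlined and compiled each copy separately. Only the jax.random API calls below are taken from this repository; the surrounding function is illustrative.

import jax
import jax.numpy as np   # jax.numpy, following this repo's `np` convention
from jax import random

@jax.jit
def sample_many(key):
  # Tracing this Python loop emits one ThreeFry2x32 hash per split/normal
  # call, so the XLA computation contains many copies of the PRNG kernel.
  outs = []
  for _ in range(100):
    key, subkey = random.split(key)
    outs.append(random.normal(subkey, (128,)))
  return np.stack(outs)

samples = sample_many(random.PRNGKey(0))

With this commit, each of those copies lowers on GPU to a single cuda_threefry2x32 custom call instead of an inlined, unrolled hash, which is what keeps compilation time in check.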


@@ -23,10 +23,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "d83221d413fd510ac8bc68ae158fcc17c2300c7c8c3bd439e30300438c2e3ce0",
strip_prefix = "tensorflow-d85c68d5bdcb3f72abfe22a73a638c11def69a7e",
sha256 = "412ef0824d5dcfe6e139e1fa25f72569e699b3ec06d374c0e19ba0bf60c32952",
strip_prefix = "tensorflow-883b5becaced22f7dd9e3c23d9d259f55e087cb5",
urls = [
"https://github.com/tensorflow/tensorflow/archive/d85c68d5bdcb3f72abfe22a73a638c11def69a7e.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/883b5becaced22f7dd9e3c23d9d259f55e087cb5.tar.gz",
],
)


@@ -14,10 +14,7 @@
# JAX is Autograd and XLA
load(
"@org_tensorflow//tensorflow/core/platform:default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
licenses(["notice"]) # Apache 2
@@ -32,9 +29,10 @@ sh_binary(
"//jaxlib",
"//jaxlib:lapack.so",
"//jaxlib:pytree",
] + if_cuda_is_configured([
] + if_cuda([
"//jaxlib:cublas_kernels",
"//jaxlib:cusolver_kernels",
"//jaxlib:cuda_prng_kernels",
]),
deps = ["@bazel_tools//tools/bash/runfiles"],
)


@@ -177,6 +177,9 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
# Make Bazel print out all options from rc files.
build --announce_rc
# Disable enabled-by-default TensorFlow features that we don't care about.
build --define=no_aws_support=true
build --define=no_gcp_support=true


@@ -57,9 +57,11 @@ cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
cp -f "$(rlocation __main__/jaxlib/cublas_kernels.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cuda_prng_kernels.so)" "${TARGET}/jaxlib"
fi
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cuda_prng.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
"${TARGET}/jaxlib"
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so)" \


@@ -182,13 +182,15 @@ def broadcast_batcher(prim, args, dims, **params):
if len(shapes) == 1:
# if there's only agreeing batch dims and scalars, just call the primitive
d = next(d for d in dims if d is not not_mapped)
return prim.bind(*args, **params), d
out = prim.bind(*args, **params)
return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
else:
size, = {shape[d] for shape, d in shapes if d is not not_mapped}
args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
ndim = max(onp.ndim(x) for x in args) # special-case scalar broadcasting
args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
return prim.bind(*args, **params), 0
out = prim.bind(*args, **params)
return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
def _handle_scalar_broadcasting(nd, x, d):
if d is not_mapped or nd == onp.ndim(x):


@@ -20,7 +20,8 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_select_and_gather_add, _float, _complex,
_input_dtype, _const, _eq_meet, _safe_mul,
_broadcasting_select, _check_user_dtype_supported,
_one, _const, _upcast_fp16_for_computation)
_one, _const, _upcast_fp16_for_computation,
_broadcasting_shape_rule)
from .lax_control_flow import *
from .lax_fft import *
from .lax_parallel import *


@@ -53,3 +53,7 @@ from jaxlib import lapack
from jaxlib import pytree
from jaxlib import cusolver
try:
from jaxlib import cuda_prng
except ImportError:
cuda_prng = None


@@ -25,6 +25,7 @@ from __future__ import division
from __future__ import print_function
from functools import partial
import itertools
import numpy as onp
@@ -35,9 +36,13 @@ from . import dtypes
from .api import custom_transforms, defjvp, jit, vmap
from .numpy.lax_numpy import _constant_like, asarray, stack
from jax.lib import xla_bridge
from jax.lib import cuda_prng
from jax import core
from jax import abstract_arrays
from jax.scipy.special import logit
from jax.scipy.linalg import cholesky
from jax.interpreters import batching
from jax.interpreters import xla
def PRNGKey(seed):
@@ -92,9 +97,18 @@ def _bit_stats(bits):
### hash function and split
def _threefry2x32_abstract_eval(*args):
if any(a.dtype != np.uint32 for a in args):
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
.format(args))
if all(isinstance(arg, abstract_arrays.ShapedArray) for arg in args):
shape = lax._broadcasting_shape_rule(*args)
aval = abstract_arrays.ShapedArray(shape, np.dtype(np.uint32))
else:
aval = abstract_arrays.UnshapedArray(np.dtype(np.uint32))
return (aval,) * 2
@jit
def threefry_2x32(keypair, count):
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
"""Apply the Threefry 2x32 hash.
Args:
@@ -104,13 +118,8 @@ def threefry_2x32(keypair, count):
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
# Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
key1, key2 = keypair
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
rotate_left = _make_rotate_left(lax.dtype(count))
x = [x1, x2]
rotate_left = _make_rotate_left(onp.uint32)
def apply_round(v, rot):
v = v[:]
@@ -119,24 +128,11 @@ def threefry_2x32(keypair, count):
v[1] = v[0] ^ v[1]
return v
odd_size = count.size % 2
if odd_size:
x = list(np.split(np.concatenate([count.ravel(), onp.uint32([0])]), 2))
else:
x = list(np.split(count.ravel(), 2))
rotations = [onp.array([13, 15, 26, 6], dtype=onp.uint32),
onp.array([17, 29, 16, 24], dtype=onp.uint32)]
ks = [key1, key2, key1 ^ key2 ^ onp.uint32(0x1BD11BDA)]
# TODO(mattjj): see https://github.com/google/jax/issues/1267, as a hopefully
# temporary workaround for the facts that (1) XLA:CPU compile time is too slow
# with unrolled loops and (2) XLA:GPU execution time is too slow with rolled
# loops, we switch on whether the default backend is CPU or GPU. If this kind
# of switch ends up sticking around, we should take into account #1211 and put
# the switch in the translation rule rather than here in the traceable.
use_rolled_loops = xla_bridge.get_backend().platform == "cpu"
x[0] = x[0] + ks[0]
x[1] = x[1] + ks[1]
@@ -176,6 +172,56 @@ def threefry_2x32(keypair, count):
x[0] = x[0] + ks[2]
x[1] = x[1] + ks[0] + onp.uint32(5)
return tuple(x)
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
shape = lax.broadcast_shapes(
c.GetShape(k1).dimensions(), c.GetShape(k2).dimensions(),
c.GetShape(x1).dimensions(), c.GetShape(x2).dimensions())
rank = len(shape)
def _broadcast(x):
ndims = c.GetShape(x).rank()
return c.BroadcastInDim(x, shape, tuple(range(rank - ndims, rank)))
return cuda_prng.threefry2x32(
c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
xla.translations[threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False), instantiate=True)
xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True), instantiate=True)
if cuda_prng:
xla.backend_specific_translations['gpu'][threefry2x32_p] = \
_threefry2x32_gpu_translation_rule
@jit
def threefry_2x32(keypair, count):
"""Apply the Threefry 2x32 hash.
Args:
keypair: a pair of 32bit unsigned integers used for the key.
count: an array of dtype uint32 used for the counts.
Returns:
An array of dtype uint32 with the same shape as `count`.
"""
key1, key2 = keypair
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
odd_size = count.size % 2
if odd_size:
x = list(np.split(np.concatenate([count.ravel(), onp.uint32([0])]), 2))
else:
x = list(np.split(count.ravel(), 2))
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = np.concatenate(x)
assert out.dtype == onp.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)
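
The user-facing entry point is unchanged by this refactor: threefry_2x32 keeps its signature and now binds the new threefry2x32_p primitive, which lowers to the hand-written kernel on GPU when jaxlib's cuda_prng module is importable, and to the rolled/unrolled XLA lowering otherwise. As a quick illustrative check (values taken from the known-answer test later in this commit; not part of the diff itself):

import numpy as onp
from jax import random

# Standard ThreeFry2x32 test vectors, as used in tests/random_test.py below.
key = (onp.uint32(0x13198a2e), onp.uint32(0x03707344))
counts = onp.array([0x243f6a88, 0x85a308d3], dtype=onp.uint32)
out = random.threefry_2x32(key, counts)
# Per those test vectors, `out` should be uint32 [0xc4923a9c, 0x483df7a0].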


@@ -16,11 +16,24 @@
load("@org_tensorflow//tensorflow/core/platform:default/build_config.bzl", "pyx_library")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "kernel_helpers",
hdrs = ["kernel_helpers.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
deps = [
"@com_google_absl//absl/base",
"@pybind11",
],
)
pyx_library(
name = "lapack",
srcs = ["lapack.pyx"],
@@ -30,11 +43,21 @@ pyx_library(
py_library(
name = "jaxlib",
srcs = [
"cuda_prng.py",
"cusolver.py",
"version.py",
],
)
py_library(
name = "gpu_support",
deps = [
":cublas_kernels",
":cuda_prng_kernels",
":cusolver_kernels",
],
)
pybind_extension(
name = "pytree",
srcs = ["pytree.cc"],
@@ -65,6 +88,7 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "cublas_kernels",
deps = [
":kernel_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
@@ -91,6 +115,7 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "cusolver_kernels",
deps = [
":kernel_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
@@ -102,6 +127,37 @@ pybind_extension(
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart",
"@local_config_cuda//cuda:cusolver",
],
)
cuda_library(
name = "cuda_prng_kernels_lib",
srcs = ["cuda_prng_kernels.cu.cc"],
hdrs = ["cuda_prng_kernels.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
deps = [
":kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
],
)
pybind_extension(
name = "cuda_prng_kernels",
srcs = ["cuda_prng_kernels.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "cuda_prng_kernels",
deps = [
":cuda_prng_kernels_lib",
":kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart",
"@pybind11",
],
)


@@ -28,6 +28,7 @@ limitations under the License.
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/kernel_helpers.h"
namespace jax {
namespace {
@@ -186,27 +187,6 @@ int SizeOfType(Type type) {
}
}
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// Packs a descriptor object into a py::bytes structure.
template <typename T>
py::bytes PackDescriptor(const T& descriptor) {
return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid size for linalg operation descriptor.");
}
return absl::bit_cast<const T*>(opaque);
}
// Builds an array of pointers to each array in a batch, in device memory.
template <typename T>
cudaError_t MakeBatchPointers(T* buffer, T** dev_ptrs, int batch,
@@ -390,11 +370,6 @@ void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
}
}
template <typename T>
py::capsule EncapsulateFunction(T* fn) {
return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
py::dict Registrations() {
py::dict dict;
dict["cublas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
@@ -409,4 +384,4 @@ PYBIND11_MODULE(cublas_kernels, m) {
}
} // namespace
} // namespace jax
} // namespace jax

jaxlib/cuda_prng.py (new file, 58 lines)

@@ -0,0 +1,58 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import operator
import numpy as np
from six.moves import reduce
from jaxlib import xla_client
try:
from jaxlib import cuda_prng_kernels
for _name, _value in cuda_prng_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
pass
_prod = lambda xs: reduce(operator.mul, xs, 1)
def threefry2x32(c, keys, data):
"""ThreeFry2x32 kernel for GPU."""
assert len(keys) == 2, keys
assert len(data) == 2, data
dims = c.GetShape(keys[0]).dimensions()
dtype = np.dtype(np.uint32)
for x in itertools.chain(keys, data):
x_shape = c.GetShape(x)
assert x_shape.element_type() == dtype
assert dims == x_shape.dimensions(), (dims, x_shape)
ndims = len(dims)
opaque = cuda_prng_kernels.cuda_threefry2x32_descriptor(_prod(dims))
layout = tuple(range(ndims - 1, -1, -1))
shape = xla_client.Shape.array_shape(dtype, dims, layout)
return c.CustomCall(
b"cuda_threefry2x32",
operands=(keys[0], keys[1], data[0], data[1]),
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
operand_shapes_with_layout=(shape,) * 4,
opaque=opaque)


@@ -0,0 +1,36 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda_prng_kernels.h"
#include "jaxlib/kernel_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace {
pybind11::dict Registrations() {
pybind11::dict dict;
dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32);
return dict;
}
PYBIND11_MODULE(cuda_prng_kernels, m) {
m.def("registrations", &Registrations);
m.def("cuda_threefry2x32_descriptor", &BuildCudaThreeFry2x32Descriptor);
}
} // namespace
} // namespace jax


@@ -0,0 +1,128 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include "jaxlib/cuda_prng_kernels.h"
#include "jaxlib/kernel_helpers.h"
namespace jax {
namespace {
__global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
const std::uint32_t* key1,
const std::uint32_t* data0,
const std::uint32_t* data1,
std::uint32_t* out0, std::uint32_t* out1,
std::int64_t n) {
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
idx += blockDim.x * gridDim.x) {
// Rotation distances specified by the Threefry2x32 algorithm.
std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
std::uint32_t x[2];
std::uint32_t ks[3];
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
ks[2] = 0x1BD11BDA;
ks[0] = key0[idx];
x[0] = data0[idx];
ks[2] = ks[2] ^ key0[idx];
ks[1] = key1[idx];
x[1] = data1[idx];
ks[2] = ks[2] ^ key1[idx];
auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
return (v << distance) | (v >> (32 - distance));
};
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
v[0] += v[1];
v[1] = rotate_left(v[1], rotation);
v[1] ^= v[0];
};
// There are no known statistical flaws with 13 rounds of Threefry2x32.
// We are conservative and use 20 rounds.
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1];
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 1u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[2];
x[1] = x[1] + ks[0] + 2u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1] + 3u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 4u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
out0[idx] = x[0] + ks[2];
out1[idx] = x[1] + ks[0] + 5u;
}
}
} // namespace
struct ThreeFry2x32Descriptor {
std::int64_t n;
};
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptor(ThreeFry2x32Descriptor{n});
}
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
std::array<const std::uint32_t*, 2> data;
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
std::array<std::uint32_t*, 2> out;
out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
const auto& descriptor =
*UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
const int block_dim = 128;
const std::int64_t grid_dim =
std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], descriptor.n);
}
} // namespace jax


@@ -0,0 +1,33 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_PRNG_KERNELS_H_
#define JAXLIB_PRNG_KERNELS_H_
#include <cstddef>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "include/pybind11/pybind11.h"
namespace jax {
pybind11::bytes BuildCudaThreeFry2x32Descriptor(std::int64_t n);
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len);
} // namespace jax
#endif // JAXLIB_PRNG_KERNELS_H_


@@ -29,6 +29,7 @@ limitations under the License.
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/kernel_helpers.h"
namespace jax {
namespace {
@@ -192,27 +193,6 @@ int SizeOfType(Type type) {
}
}
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// Packs a descriptor object into a py::bytes structure.
template <typename T>
py::bytes PackDescriptor(const T& descriptor) {
return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid size for linalg operation descriptor.");
}
return absl::bit_cast<const T*>(opaque);
}
// getrf: LU decomposition
struct GetrfDescriptor {
@@ -1150,11 +1130,6 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
}
}
template <typename T>
py::capsule EncapsulateFunction(T* fn) {
return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
py::dict Registrations() {
py::dict dict;
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);

jaxlib/kernel_helpers.h (new file, 55 lines)

@@ -0,0 +1,55 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_KERNEL_HELPERS_H_
#define JAXLIB_KERNEL_HELPERS_H_
#include <cstddef>
#include <stdexcept>
#include "absl/base/casts.h"
#include "include/pybind11/pybind11.h"
namespace jax {
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// Packs a descriptor object into a pybind11::bytes structure.
template <typename T>
pybind11::bytes PackDescriptor(const T& descriptor) {
return pybind11::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid size for linalg operation descriptor.");
}
return absl::bit_cast<const T*>(opaque);
}
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
} // namespace jax
#endif // JAXLIB_KERNEL_HELPERS_H_


@@ -95,6 +95,17 @@ class LaxRandomTest(jtu.JaxTestCase):
onp.uint32([0x243f6a88, 0x85a308d3]))
self.assertEqual(expected, result_to_hex(result))
def testThreefry2x32Large(self):
n = 10000000
result = random.threefry_2x32(
(onp.uint32(0x13198a2e), onp.uint32(0x03707344)),
np.concatenate([
np.full((n,), 0x243f6a88, np.uint32),
np.full((n,), 0x85a308d3, np.uint32)
]))
onp.testing.assert_equal(result[:n], onp.full((n,), 0xc4923a9c, dtype=onp.uint32))
onp.testing.assert_equal(result[n:], onp.full((n,), 0x483df7a0, dtype=onp.uint32))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))