mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 19:16:06 +00:00

Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.

PiperOrigin-RevId: 394460343
430 lines
15 KiB
Python
430 lines
15 KiB
Python
# Copyright 2019 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import functools
|
|
import operator
|
|
|
|
import numpy as np
|
|
|
|
from jaxlib import xla_client
|
|
|
|
# The GPU kernel extension modules are optional (presumably absent from builds
# without CUDA support -- confirm against the build rules). Only the import
# itself is guarded: an error raised while registering the kernels must surface
# rather than be mistaken for a missing extension.
try:
  from . import _cublas
except ImportError:
  pass
else:
  for _name, _value in _cublas.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")

try:
  from . import _cusolver
except ImportError:
  pass
else:
  for _name, _value in _cusolver.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")

# Short aliases used by the builder functions in this module.
_ops = xla_client.ops
_Shape = xla_client.Shape
|
|
|
|
# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
  """Returns the builder object behind `c`.

  If `c` exposes a `_builder` attribute (as a ComputationBuilder wrapper
  does), that attribute is returned; otherwise `c` is returned unchanged.
  """
  try:
    return c._builder
  except AttributeError:
    # `c` is already the underlying builder.
    return c
|
|
|
|
def _real_type(dtype):
  """Maps a floating or complex dtype to the matching real NumPy type.

  float32 and complex64 map to `np.float32`; float64 and complex128 map to
  `np.float64`.

  Raises:
    NotImplementedError: if `dtype` compares equal to none of the four
      supported types.
  """
  # Candidates are compared in the listed order using `==`.
  real_of = (
      (np.float32, np.float32),
      (np.float64, np.float64),
      (np.complex64, np.float32),
      (np.complex128, np.float64),
  )
  for candidate, real in real_of:
    if dtype == candidate:
      return real
  raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
|
|
|
def _prod(xs):
  """Returns the product of the items of `xs`, or 1 if `xs` is empty."""
  result = 1
  for x in xs:
    result = result * x
  return result
|
|
|
|
def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
         diag=False):
  """Emits a batched triangular solve as a `cublas_trsm_batched` custom call.

  Only the batched case is handled here; XLA implements the unbatched
  triangular solve directly.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand whose shape must be `batch_dims + (k, k)`, where `k` is `m` if
      `left_side` is set and `n` otherwise.
    b: operand of shape `batch_dims + (m, n)`; it fixes the element type.
    left_side, lower, trans_a, conj_a, diag: flags forwarded unchanged to
      `_cublas.build_trsm_batched_descriptor`.

  Returns:
    Element 0 of the custom call's result tuple, which is declared with the
    dimensions and element type of `b`.

  Raises:
    ValueError: if the dimensions or element type of `a` do not match `b`.
    NotImplementedError: if `conj_a` is set without `trans_a`.
  """
  builder = _unpack_builder(c)
  rhs_shape = builder.get_shape(b)
  elem_type = rhs_shape.element_type()
  rhs_dims = rhs_shape.dimensions()
  assert len(rhs_dims) >= 2
  m, n = rhs_dims[-2:]
  batch_dims = tuple(rhs_dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  # `a` is square; its size follows the side of `b` it is applied to.
  if left_side:
    k = m
  else:
    k = n

  lhs_shape = builder.get_shape(a)
  mismatch = (batch_dims + (k, k) != lhs_shape.dimensions() or
              lhs_shape.element_type() != elem_type)
  if mismatch:
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        lhs_shape, rhs_shape))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

  lwork, opaque = _cublas.build_trsm_batched_descriptor(
      np.dtype(elem_type), batch, m, n, left_side, lower, trans_a, conj_a,
      diag)

  # The two matrix dimensions come first in the layout, followed by the batch
  # dimensions in reverse order.
  matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

  def matrix_shape(shape):
    return _Shape.array_shape(elem_type, shape.dimensions(), matrix_layout)

  def workspace_shape():
    # int8 buffer of `lwork` elements (presumably scratch space for the
    # kernel -- confirm against the C++ side).
    return _Shape.array_shape(np.dtype(np.int8), (lwork,), (0,))

  result_shape = _Shape.tuple_shape(
      (matrix_shape(rhs_shape), workspace_shape(), workspace_shape()))
  operand_shapes = (matrix_shape(lhs_shape), matrix_shape(rhs_shape))

  out = _ops.CustomCallWithLayout(
      builder, b"cublas_trsm_batched",
      operands=(a, b),
      shape_with_layout=result_shape,
      operand_shapes_with_layout=operand_shapes,
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  return _ops.GetTupleElement(out, 0)
|
|
|
|
|
|
def potrf(c, a, lower):
  """Emits a Cholesky decomposition as a `cusolver_potrf` custom call.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (n, n)`.
    lower: forwarded unchanged to `_cusolver.build_potrf_descriptor`.

  Returns:
    A pair of ops: element 0 of the result tuple (same dimensions and element
    type as `a`) and element 1 (an int32 array of shape `batch_dims`;
    presumably the per-matrix solver status -- confirm against the kernel).
  """
  builder = _unpack_builder(c)
  operand_shape = builder.get_shape(a)
  elem_type = operand_shape.element_type()
  dims = operand_shape.dimensions()
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  lwork, opaque = _cusolver.build_potrf_descriptor(
      np.dtype(elem_type), lower, batch, n)

  # Batch dimensions in reverse order; matrices put their two own dimensions
  # ahead of them.
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  matrix_layout = (num_bd, num_bd + 1) + batch_layout

  def matrix_shape():
    return _Shape.array_shape(elem_type, batch_dims + (n, n), matrix_layout)

  info_shape = _Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout)
  workspace_shape = _Shape.array_shape(np.dtype(np.int8), (lwork,), (0,))

  out = _ops.CustomCallWithLayout(
      builder, b"cusolver_potrf",
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape(
          (matrix_shape(), info_shape, workspace_shape)),
      operand_shapes_with_layout=(matrix_shape(),),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  factor = _ops.GetTupleElement(out, 0)
  info = _ops.GetTupleElement(out, 1)
  return factor, info
|
|
|
|
|
|
def getrf(c, a):
  """Emits an LU decomposition as a cuBLAS or cuSOLVER custom call.

  The batched cuBLAS kernel is used when there is more than one matrix, the
  matrices are square and `m // batch <= 128`; otherwise the cuSOLVER kernel
  is used.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (m, n)`.

  Returns:
    A triple of ops, elements 0-2 of the result tuple: an array with the
    dimensions and element type of `a`, an int32 array of shape
    `batch_dims + (min(m, n),)`, and an int32 array of shape `batch_dims`
    (presumably factors, pivots and status -- confirm against the kernels).
  """
  builder = _unpack_builder(c)
  operand_shape = builder.get_shape(a)
  elem_type = operand_shape.element_type()
  dims = operand_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  use_batched_cublas = batch > 1 and m == n and m // batch <= 128
  if use_batched_cublas:
    kernel = b"cublas_getrf_batched"
    lwork, opaque = _cublas.build_getrf_batched_descriptor(
        np.dtype(elem_type), batch, m)
    # The cuBLAS path sizes its workspace in int8 elements.
    workspace_type = np.dtype(np.int8)
  else:
    kernel = b"cusolver_getrf"
    lwork, opaque = _cusolver.build_getrf_descriptor(
        np.dtype(elem_type), batch, m, n)
    # The cuSOLVER path sizes its workspace in elements of the operand type.
    workspace_type = elem_type
  workspace = _Shape.array_shape(workspace_type, (lwork,), (0,))

  batch_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + batch_layout
  matrix_layout = (num_bd, num_bd + 1) + batch_layout

  def matrix_shape():
    return _Shape.array_shape(elem_type, batch_dims + (m, n), matrix_layout)

  pivots_shape = _Shape.array_shape(
      np.dtype(np.int32), batch_dims + (min(m, n),), vector_layout)
  info_shape = _Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout)

  out = _ops.CustomCallWithLayout(
      builder, kernel,
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape(
          (matrix_shape(), pivots_shape, info_shape, workspace)),
      operand_shapes_with_layout=(matrix_shape(),),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  return tuple(_ops.GetTupleElement(out, index) for index in (0, 1, 2))
|
|
|
|
def geqrf(c, a):
  """Emits a QR decomposition as a `cusolver_geqrf` custom call.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (m, n)`.

  Returns:
    A triple of ops, elements 0-2 of the result tuple: an array with the
    dimensions and element type of `a`, an array of the same element type and
    shape `batch_dims + (min(m, n),)`, and an int32 array of shape
    `batch_dims` (presumably the status -- confirm against the kernel).
  """
  builder = _unpack_builder(c)
  operand_shape = builder.get_shape(a)
  elem_type = operand_shape.element_type()
  dims = operand_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  lwork, opaque = _cusolver.build_geqrf_descriptor(
      np.dtype(elem_type), batch, m, n)
  workspace = _Shape.array_shape(elem_type, (lwork,), (0,))

  batch_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + batch_layout
  matrix_layout = (num_bd, num_bd + 1) + batch_layout

  def matrix_shape():
    return _Shape.array_shape(elem_type, batch_dims + (m, n), matrix_layout)

  tau_shape = _Shape.array_shape(
      elem_type, batch_dims + (min(m, n),), vector_layout)
  info_shape = _Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout)

  out = _ops.CustomCallWithLayout(
      builder, b"cusolver_geqrf",
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape(
          (matrix_shape(), tau_shape, info_shape, workspace)),
      operand_shapes_with_layout=(matrix_shape(),),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  return tuple(_ops.GetTupleElement(out, index) for index in (0, 1, 2))
|
|
|
|
def orgqr(c, a, tau):
  """Emits a product of elementary Householder reflections.

  The computation is emitted as a `cusolver_orgqr` custom call.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (m, n)`.
    tau: operand of shape `batch_dims + (k,)`; its leading dimensions must
      equal the batch dimensions of `a`.

  Returns:
    A pair of ops, elements 0 and 1 of the result tuple: an array with the
    dimensions and element type of `a`, and an int32 array of shape
    `batch_dims` (presumably the status -- confirm against the kernel).
  """
  builder = _unpack_builder(c)
  operand_shape = builder.get_shape(a)
  elem_type = operand_shape.element_type()
  dims = operand_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)

  tau_dims = builder.get_shape(tau).dimensions()
  assert tau_dims[:-1] == dims[:-2]
  k = tau_dims[-1]

  lwork, opaque = _cusolver.build_orgqr_descriptor(
      np.dtype(elem_type), batch, m, n, k)
  workspace = _Shape.array_shape(elem_type, (lwork,), (0,))

  batch_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + batch_layout
  matrix_layout = (num_bd, num_bd + 1) + batch_layout

  def matrix_shape():
    return _Shape.array_shape(elem_type, batch_dims + (m, n), matrix_layout)

  info_shape = _Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout)
  tau_shape = _Shape.array_shape(elem_type, batch_dims + (k,), vector_layout)

  out = _ops.CustomCallWithLayout(
      builder, b"cusolver_orgqr",
      operands=(a, tau),
      shape_with_layout=_Shape.tuple_shape(
          (matrix_shape(), info_shape, workspace)),
      operand_shapes_with_layout=(matrix_shape(), tau_shape),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  product = _ops.GetTupleElement(out, 0)
  info = _ops.GetTupleElement(out, 1)
  return (product, info)
|
|
|
|
|
|
def syevd(c, a, lower=False):
  """Emits a symmetric (Hermitian) eigendecomposition custom call.

  Matrices with `n <= 32` use the `cusolver_syevj` kernel; larger ones use
  `cusolver_syevd`.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (n, n)`.
    lower: forwarded unchanged to the descriptor builder.

  Returns:
    A triple of ops, elements 0-2 of the result tuple: an array with the
    dimensions and element type of `a`, an array of the corresponding real
    type with shape `batch_dims + (n,)`, and an int32 array of shape
    `batch_dims` (presumably eigenvectors, eigenvalues and status -- confirm
    against the kernels).
  """
  builder = _unpack_builder(c)

  operand_shape = builder.get_shape(a)
  elem_type = operand_shape.element_type()
  dims = operand_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + batch_layout
  matrix_layout = (num_bd, num_bd + 1) + batch_layout

  # Both descriptor builders take the same arguments; only the kernel differs.
  if n <= 32:
    kernel = b"cusolver_syevj"
    build_descriptor = _cusolver.build_syevj_descriptor
  else:
    kernel = b"cusolver_syevd"
    build_descriptor = _cusolver.build_syevd_descriptor
  lwork, opaque = build_descriptor(np.dtype(elem_type), lower, batch, n)
  eigvals_type = _real_type(elem_type)

  def matrix_shape():
    return _Shape.array_shape(elem_type, dims, matrix_layout)

  eigvals_shape = _Shape.array_shape(
      np.dtype(eigvals_type), batch_dims + (n,), vector_layout)
  info_shape = _Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout)
  workspace_shape = _Shape.array_shape(elem_type, (lwork,), (0,))

  out = _ops.CustomCallWithLayout(
      builder, kernel,
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape(
          (matrix_shape(), eigvals_shape, info_shape, workspace_shape)),
      operand_shapes_with_layout=(matrix_shape(),),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))
  return tuple(_ops.GetTupleElement(out, index) for index in (0, 1, 2))
|
|
|
|
|
|
def gesvd(c, a, full_matrices=True, compute_uv=True):
  """Emits a singular value decomposition as a cuSOLVER custom call.

  Three cases are distinguished:
    * `m < 32 and n < 32`: the `cusolver_gesvdj` kernel; its third output is
      transposed (and conjugated for complex types) to produce `vt`.
    * `m < n`: the `cusolver_gesvd` kernel with `n` and `m` swapped in the
      descriptor and the two matrix dimensions swapped in the layout; the
      outputs at tuple positions 2 and 3 are then `vt` and `u`.
    * otherwise: the `cusolver_gesvd` kernel; positions 2 and 3 are `u` and
      `vt`.

  Args:
    c: a builder, or a ComputationBuilder wrapping one.
    a: operand of shape `batch_dims + (m, n)`.
    full_matrices: if false, `u` is sliced to `batch_dims + (m, min(m, n))`
      and `vt` to `batch_dims + (min(m, n), n)`. Also forwarded to the gesvd
      descriptor builder.
    compute_uv: forwarded unchanged to the descriptor builder.

  Returns:
    A tuple `(s, u, vt, info)` of ops.
  """
  builder = _unpack_builder(c)

  operand_shape = builder.get_shape(a)
  dims = operand_shape.dimensions()
  elem_type = operand_shape.element_type()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch = _prod(batch_dims)
  singular_vals_dtype = np.dtype(_real_type(elem_type))

  scalar_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + scalar_layout

  use_jacobi = m < 32 and n < 32
  swap_operand = (not use_jacobi) and m < n

  # Pick the kernel, its descriptor, the matrix layout and the sizes of the
  # two square outputs at tuple positions 2 and 3.
  if use_jacobi:
    kernel = b"cusolver_gesvdj"
    lwork, opaque = _cusolver.build_gesvdj_descriptor(
        np.dtype(elem_type), batch, m, n, compute_uv)
    matrix_layout = (num_bd, num_bd + 1) + scalar_layout
    size_2, size_3 = m, n
  elif swap_operand:
    kernel = b"cusolver_gesvd"
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(elem_type), batch, n, m, compute_uv, full_matrices)
    matrix_layout = (num_bd + 1, num_bd) + scalar_layout
    size_2, size_3 = n, m
  else:
    kernel = b"cusolver_gesvd"
    lwork, opaque = _cusolver.build_gesvd_descriptor(
        np.dtype(elem_type), batch, m, n, compute_uv, full_matrices)
    matrix_layout = (num_bd, num_bd + 1) + scalar_layout
    size_2, size_3 = m, n

  def matrix_shape(rows, cols):
    return _Shape.array_shape(
        elem_type, batch_dims + (rows, cols), matrix_layout)

  out = _ops.CustomCallWithLayout(
      builder, kernel,
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape((
          matrix_shape(m, n),
          _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                             vector_layout),
          matrix_shape(size_2, size_2),
          matrix_shape(size_3, size_3),
          _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
          _Shape.array_shape(elem_type, (lwork,), (0,)),
      )),
      operand_shapes_with_layout=(matrix_shape(m, n),),
      opaque=opaque,
      api_version=(xla_client.ops.CustomCallApiVersion
                   .API_VERSION_STATUS_RETURNING))

  s = _ops.GetTupleElement(out, 1)
  output_2 = _ops.GetTupleElement(out, 2)
  output_3 = _ops.GetTupleElement(out, 3)
  info = _ops.GetTupleElement(out, 4)

  if use_jacobi:
    u = output_2
    # Swap the two matrix dimensions of the third output to obtain `vt`.
    permutation = tuple(range(num_bd)) + (num_bd + 1, num_bd)
    vt = _ops.Transpose(output_3, permutation)
    if np.issubdtype(elem_type, np.complexfloating):
      vt = _ops.Conj(vt)
  elif swap_operand:
    vt, u = output_2, output_3
  else:
    u, vt = output_2, output_3

  if not full_matrices:
    rank = len(dims)
    origin = (0,) * rank
    unit_strides = (1,) * rank
    u = _ops.Slice(u, origin, batch_dims + (m, min(m, n)), unit_strides)
    vt = _ops.Slice(vt, origin, batch_dims + (min(m, n), n), unit_strides)
  return s, u, vt, info
|