Port remaining translation rules inside JAX to new style.

PiperOrigin-RevId: 404288551
Peter Hawkins, 2021-10-19 09:47:55 -07:00, committed by jax authors
parent f66985de25
commit e783cbcb72
16 changed files with 359 additions and 289 deletions
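
Distilled from the multiply_add hunks in the first two files below, a minimal before/after sketch of the API change (the context object's concrete type is not shown in this diff; the sketch relies only on it exposing ctx.builder and ctx.platform, which the hunks use, and multiply_add_p is the tutorial's primitive):

```
from jax._src.lib import xla_client
from jax.interpreters import xla

# Old style: the rule receives the XlaBuilder and one XlaOp per argument,
# returns a single XlaOp (or an xops.Tuple for multiple results), and is
# registered by assignment into a per-platform dict.
def multiply_add_old(c, xc, yc, zc):
  return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)

xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_old

# New style: the rule receives a context plus the input and output abstract
# values, always returns a list of XlaOps (one per output), and is
# registered through a function, with an optional platform= keyword for
# backend-specific rules.
def multiply_add_new(ctx, avals_in, avals_out, xc, yc, zc):
  return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]

xla.register_translation(multiply_add_p, multiply_add_new, platform='cpu')
```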

View File

@ -616,7 +616,7 @@
"source": [
"from jax._src.lib import xla_client\n",
"@trace(\"multiply_add_xla_translation\")\n",
"def multiply_add_xla_translation(c, xc, yc, zc):\n",
"def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):\n",
" \"\"\"The compilation to XLA of the primitive.\n",
"\n",
" Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the\n",
@ -624,12 +624,12 @@
"\n",
" Does not need to be a JAX-traceable function.\n",
" \"\"\"\n",
" return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)\n",
" return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]\n",
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"# TODO: for GPU? and TPU?\n",
"from jax.interpreters import xla\n",
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
"xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')"
]
},
{

View File

@ -356,7 +356,7 @@ for most of them. However, XLA includes a `CustomCall` operation that can be use
from jax._src.lib import xla_client
@trace("multiply_add_xla_translation")
def multiply_add_xla_translation(c, xc, yc, zc):
def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):
"""The compilation to XLA of the primitive.
Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
@ -364,12 +364,12 @@ def multiply_add_xla_translation(c, xc, yc, zc):
Does not need to be a JAX-traceable function.
"""
return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)
return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]
# Now we register the XLA compilation rule with JAX
# TODO: for GPU? and TPU?
from jax.interpreters import xla
xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation
xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')
```
+++ {"id": "K98LX-VaJkFu"}

View File

@ -405,7 +405,8 @@ def name_jvp(primals, tangents, *, name):
return name_p.bind(x, name=name), xdot # don't name the tangent value
ad.primitive_jvps[name_p] = name_jvp
xla.translations[name_p] = lambda c, x, *, name: x
xla.register_translation(name_p,
lambda ctx, avals_in, avals_out, x, *, name: [x])
def name_batcher(args, dims, *, name):
(x,), (d,) = args, dims

View File

@ -696,7 +696,7 @@ xla.register_translation(
initial_style=True)
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
pe.partial_eval_jaxpr_custom_rules[custom_vjp_call_jaxpr_p] = \
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore

View File

@ -2718,10 +2718,11 @@ def _cumulative_reduction_primitive(name,
translation_rule=xla.lower_fun(
partial(associative_scan, reduce_fn),
multiple_results=False, new_style=True))
xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
multiple_results=False)
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
xla.register_translation(reducer_p, xla.lower_fun(
partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
multiple_results=False, new_style=True), platform='tpu')
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
reducer_p)
return reducer_p
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum)
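
The hunk above also shows the second registration path this commit migrates: rules produced by xla.lower_fun now pass new_style=True, so the wrapped JAX-traceable implementation is exposed with the new signature. A hedged sketch, with a hypothetical primitive my_p and implementation (lower_fun, multiple_results, new_style, and register_translation are the pieces taken from the diff):

```
def _my_impl(x):
  # Any JAX-traceable function can serve as the lowering.
  return 2 * x + 1

# With new_style=True, lower_fun returns a rule with the
# (ctx, avals_in, avals_out, *args) signature that register_translation
# expects.
xla.register_translation(
    my_p, xla.lower_fun(_my_impl, multiple_results=False, new_style=True))
```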

View File

@ -77,8 +77,14 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
dtype = x.dtype
return x.update(shape=shape, dtype=dtype)
def fft_translation_rule(c, x, fft_type, fft_lengths):
return xops.Fft(x, fft_type, fft_lengths)
def _fft_translation_rule(ctx, avals_in, avals_out, x, *, fft_type,
fft_lengths):
return [xops.Fft(x, fft_type, fft_lengths)]
def _fft_translation_rule_cpu(ctx, avals_in, avals_out, x, *, fft_type,
fft_lengths):
return [pocketfft.pocketfft(ctx.builder, x, fft_type=fft_type,
fft_lengths=fft_lengths)]
def _naive_rfft(x, fft_lengths):
y = fft(x, xla_client.FftType.FFT, fft_lengths)
@ -138,8 +144,8 @@ def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
xla.register_translation(fft_p, _fft_translation_rule)
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
xla.register_translation(fft_p, _fft_translation_rule_cpu, platform='cpu')
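
The fft registrations above follow the pattern repeated throughout the rest of this commit: a portable default rule, then a platform-specific override registered only when the backend module is available. A hedged sketch of that shape (the primitive, rules, and _my_cpu_kernels guard are hypothetical):

```
xla.register_translation(my_p, _my_default_rule)   # all platforms

if _my_cpu_kernels is not None:                    # optional backend module
  # A registration with platform= overrides the default on that backend.
  xla.register_translation(my_p, _my_cpu_rule, platform='cpu')
```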

View File

@ -350,25 +350,33 @@ def _nan_like(c, operand):
nan = xops.Constant(c, np.array(np.nan, dtype=dtype))
return xops.Broadcast(nan, shape.dimensions())
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
shape = c.get_shape(operand)
batch_dims = shape.dimensions()[:-2]
def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out,
operand):
operand_aval, = avals_in
c = ctx.builder
batch_dims = operand_aval.shape[:-2]
result, info = potrf_impl(c, operand, lower=True)
ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
return _broadcasting_select(c,
xops.Reshape(ok, batch_dims + (1, 1)), result,
_nan_like(c, result))
ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
return [_broadcasting_select(c,
xops.Reshape(ok, batch_dims + (1, 1)), result,
_nan_like(c, result))]
xla.backend_specific_translations['cpu'][cholesky_p] = partial(
_cholesky_cpu_gpu_translation_rule, lapack.potrf)
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, lapack.potrf),
platform='cpu')
if cusolver is not None:
xla.backend_specific_translations['gpu'][cholesky_p] = partial(
_cholesky_cpu_gpu_translation_rule, cusolver.potrf)
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, cusolver.potrf),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][cholesky_p] = partial(
_cholesky_cpu_gpu_translation_rule, rocsolver.potrf)
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, rocsolver.potrf),
platform='gpu')
# Asymmetric eigendecomposition
@ -409,15 +417,17 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
_cpu_geev = lapack.geev
def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
shape = c.get_shape(operand)
batch_dims = shape.dimensions()[:-2]
def _eig_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
compute_left_eigenvectors,
compute_right_eigenvectors):
operand_aval, = avals_in
batch_dims = operand_aval.shape[:-2]
c = ctx.builder
w, vl, vr, info = _cpu_geev(c, operand, jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
_nan_like(c, w))
output = [w]
@ -432,7 +442,7 @@ def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors,
_nan_like(c, vr))
output.append(vr)
return xops.Tuple(c, output)
return output
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
compute_right_eigenvectors):
@ -462,8 +472,8 @@ eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
xla.register_translation(eig_p, eig_translation_rule)
xla.register_translation(eig_p, _eig_cpu_translation_rule, platform='cpu')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
@ -474,12 +484,11 @@ def eigh_impl(operand, lower):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
return v, w
def eigh_translation_rule(c, operand, lower):
shape = c.get_shape(operand)
dims = shape.dimensions()
if dims[-1] == 0:
return xops.Tuple(c, [operand, xops.Real(xops.Reshape(operand, dims[:-1]))])
return xops.Tuple(c, xops.Eigh(operand, lower=lower))
def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
operand_aval, = avals_in
if operand_aval.shape[-1] == 0:
return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
return xops.Eigh(operand, lower=lower)
def eigh_abstract_eval(operand, lower):
if isinstance(operand, ShapedArray):
@ -497,16 +506,18 @@ def eigh_abstract_eval(operand, lower):
v, w = operand, operand
return v, w
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
shape = c.get_shape(operand)
batch_dims = shape.dimensions()[:-2]
def _eigh_cpu_gpu_translation_rule(syevd_impl, ctx, avals_in, avals_out,
operand, *, lower):
operand_aval, = avals_in
batch_dims = operand_aval.shape[:-2]
c = ctx.builder
v, w, info = syevd_impl(c, operand, lower=lower)
ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
_nan_like(c, v))
w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
_nan_like(c, w))
return xops.Tuple(c, [v, w])
return [v, w]
def eigh_jvp_rule(primals, tangents, lower):
# Derivative for eigh in the simplest case of distinct eigenvalues.
@ -545,22 +556,25 @@ eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
xla.register_translation(eigh_p, _eigh_translation_rule)
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
batching.primitive_batchers[eigh_p] = eigh_batching_rule
_cpu_syevd = lapack.syevd
xla.backend_specific_translations['cpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, _cpu_syevd)
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, _cpu_syevd),
platform='cpu')
if cusolver is not None:
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, cusolver.syevd)
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, cusolver.syevd),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, rocsolver.syevd)
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, rocsolver.syevd),
platform='gpu')
triangular_solve_dtype_rule = partial(
@ -686,17 +700,18 @@ batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
def _triangular_solve_cpu_translation_rule(
c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
shape = c.get_shape(a)
dtype = shape.element_type().type
ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
conjugate_a, unit_diagonal):
a_aval, _ = avals_in
c = ctx.builder
if conjugate_a and not transpose_a:
a = xops.Conj(a)
conjugate_a = False
if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types:
return lapack.jax_trsm(
c, xops.Constant(c, np.array(1, dtype=dtype)),
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
return [lapack.jax_trsm(
c, xops.Constant(c, np.array(1, dtype=a_aval.dtype)),
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
else:
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
@ -705,24 +720,26 @@ def _triangular_solve_cpu_translation_rule(
else:
transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)
return [xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
transpose)]
xla.backend_specific_translations['cpu'][triangular_solve_p] = \
_triangular_solve_cpu_translation_rule
xla.register_translation(triangular_solve_p,
_triangular_solve_cpu_translation_rule,
platform='cpu')
def _triangular_solve_gpu_translation_rule(trsm_impl,
c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
shape = c.get_shape(a)
dims = shape.dimensions()
m, n = dims[-2:]
batch = prod(dims[:-2])
def _triangular_solve_gpu_translation_rule(
trsm_impl, ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
conjugate_a, unit_diagonal):
c = ctx.builder
a_aval, _ = avals_in
m, n = a_aval.shape[-2:]
batch = prod(a_aval.shape[:-2])
if conjugate_a and not transpose_a:
a = xops.Conj(a)
conjugate_a = False
if batch > 1 and m <= 256 and n <= 256:
return trsm_impl(
c, a, b, left_side, lower, transpose_a,
conjugate_a, unit_diagonal)
return [trsm_impl(c, a, b, left_side, lower, transpose_a,
conjugate_a, unit_diagonal)]
else:
# Use the XLA implementation for unbatched triangular_solve.
if not transpose_a:
@ -730,16 +747,20 @@ def _triangular_solve_gpu_translation_rule(trsm_impl,
else:
transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
transpose)
return [xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
transpose)]
if cusolver is not None:
xla.backend_specific_translations['gpu'][triangular_solve_p] = \
partial(_triangular_solve_gpu_translation_rule, cusolver.trsm)
xla.register_translation(
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, cusolver.trsm),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][triangular_solve_p] = \
partial(_triangular_solve_gpu_translation_rule, rocsolver.trsm)
xla.register_translation(
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, rocsolver.trsm),
platform='gpu')
# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.
@ -757,8 +778,7 @@ def _lu_pivots_body_fn(i, permutation_and_swaps):
return permutation.at[iotas + (j,)].set(x), swaps
@partial(api.jit, static_argnums=(1,))
def _generic_lu_pivots_to_permutation(swaps, m):
def _generic_lu_pivots_to_permutation(swaps, permutation_size):
"""Converts the pivots (row swaps) returned by LU to a permutation.
We build a permutation rather than applying `swaps` directly to the rows
@ -766,13 +786,14 @@ def _generic_lu_pivots_to_permutation(swaps, m):
Args:
swaps: an array of shape (..., k) of row swaps to perform
m: the size of the output permutation. m should be >= k.
permutation_size: the size of the output permutation. Should be >= k.
Returns:
An int32 array of shape (..., m).
"""
assert len(swaps.shape) >= 1
batch_dims = swaps.shape[:-1]
k = swaps.shape[-1]
m = permutation_size
permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
len(batch_dims))
@ -812,13 +833,10 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
return lu_pivots_to_permutation_p.bind(
x, permutation_size=permutation_size), 0
def _lu_pivots_to_permutation_translation_rule(c, pivots, *, permutation_size):
lowered_fun = xla.lower_fun(
lambda x: _generic_lu_pivots_to_permutation(x, permutation_size),
multiple_results=False)
return lowered_fun(c, pivots)
def _lu_pivots_to_permutation_gpu(ctx, avals_in, avals_out, pivots, *,
permutation_size):
return [cuda_linalg.lu_pivots_to_permutation(
ctx.builder, pivots, permutation_size=permutation_size)]
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
lu_pivots_to_permutation_p.multiple_results = False
@ -828,12 +846,15 @@ lu_pivots_to_permutation_p.def_abstract_eval(
_lu_pivots_to_permutation_abstract_eval)
batching.primitive_batchers[lu_pivots_to_permutation_p] = (
_lu_pivots_to_permutation_batching_rule)
xla.translations[lu_pivots_to_permutation_p] = (
_lu_pivots_to_permutation_translation_rule)
xla.register_translation(
lu_pivots_to_permutation_p,
xla.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False,
new_style=True))
if cuda_linalg:
xla.backend_specific_translations['gpu'][lu_pivots_to_permutation_p] = (
cuda_linalg.lu_pivots_to_permutation)
xla.register_translation(lu_pivots_to_permutation_p,
_lu_pivots_to_permutation_gpu,
platform='gpu')
# LU decomposition
@ -988,49 +1009,50 @@ def _lu_batching_rule(batched_args, batch_dims):
x = batching.moveaxis(x, bd, 0)
return lu_p.bind(x), (0, 0, 0)
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand, backend):
shape = c.get_shape(operand)
batch_dims = shape.dimensions()[:-2]
m = shape.dimensions()[-2]
def _lu_cpu_gpu_translation_rule(getrf_impl, ctx, avals_in, avals_out, operand):
operand_aval, = avals_in
c = ctx.builder
batch_dims = operand_aval.shape[:-2]
m = operand_aval.shape[-2]
lu, pivot, info = getrf_impl(c, operand)
# Subtract 1 from the pivot to get 0-based indices.
pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
pivot = xops.Sub(pivot, xops.Constant(c, np.array(1, np.int32)))
ok = xops.Ge(info, xops.Constant(c, np.array(0, np.int32)))
lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
_nan_like(c, lu))
perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
multiple_results=False, backend=backend)(c, pivot)
return xops.Tuple(c, [lu, pivot, perm])
multiple_results=False, backend=ctx.platform)(c, pivot)
return [lu, pivot, perm]
def _lu_tpu_translation_rule(c, operand):
if hasattr(xops, "LU"):
lu, pivot, perm = xops.LU(operand)
return xops.Tuple(c, [lu, pivot, perm])
else:
return xla.lower_fun(_lu_python, multiple_results=True)(c, operand)
def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand):
return xops.LU(operand)
lu_p = Primitive('lu')
lu_p.multiple_results = True
lu_p.def_impl(_lu_impl)
lu_p.def_abstract_eval(_lu_abstract_eval)
xla.translations[lu_p] = xla.lower_fun(_lu_python, multiple_results=True)
xla.register_translation(lu_p, xla.lower_fun(_lu_python, multiple_results=True,
new_style=True))
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule
xla.backend_specific_translations['cpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, lapack.getrf, backend='cpu')
xla.register_translation(lu_p,
partial(_lu_cpu_gpu_translation_rule, lapack.getrf),
platform='cpu')
if cusolver is not None:
xla.backend_specific_translations['gpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, cusolver.getrf, backend='gpu')
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, cusolver.getrf),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, rocsolver.getrf, backend='gpu')
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, rocsolver.getrf),
platform='gpu')
xla.backend_specific_translations['tpu'][lu_p] = _lu_tpu_translation_rule
xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
@ -1094,8 +1116,8 @@ def qr_impl(operand, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
def qr_translation_rule(c, operand, full_matrices):
return xops.Tuple(c, xops.QR(operand, full_matrices))
def _qr_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices):
return xops.QR(operand, full_matrices)
def qr_abstract_eval(operand, full_matrices):
if isinstance(operand, ShapedArray):
@ -1137,10 +1159,11 @@ def qr_batching_rule(batched_args, batch_dims, full_matrices):
x = batching.moveaxis(x, bd, 0)
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
full_matrices):
shape = c.get_shape(operand)
dims = shape.dimensions()
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, ctx, avals_in,
avals_out, operand, *, full_matrices):
c = ctx.builder
operand_aval, = avals_in
dims = operand_aval.shape
m, n = dims[-2:]
batch_dims = dims[:-2]
r, tau, info_geqrf = geqrf_impl(c, operand)
@ -1155,13 +1178,13 @@ def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
else:
padding_config = [(0, 0, 0)] * len(dims)
padding_config[-1] = (0, m - n, 0)
q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=shape.element_type())),
q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=operand_aval.dtype)),
xla_client.make_padding_config(padding_config))
q, info_orgqr = orgqr_impl(c, q, tau)
if info_geqrf is not None:
ok = xops.And(
xops.Eq(info_geqrf, xops.ConstantLiteral(c, np.array(0, np.int32))),
xops.Eq(info_orgqr, xops.ConstantLiteral(c, np.array(0, np.int32))))
xops.Eq(info_geqrf, xops.Constant(c, np.array(0, np.int32))),
xops.Eq(info_orgqr, xops.Constant(c, np.array(0, np.int32))))
q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
_nan_like(c, q))
r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
@ -1170,26 +1193,31 @@ def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
pass # rocsolver does not return info
r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
return xops.Tuple(c, [q, r])
return [q, r]
qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
xla.register_translation(qr_p, _qr_translation_rule)
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule
xla.backend_specific_translations['cpu'][qr_p] = partial(
_qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr)
xla.register_translation(
qr_p, partial(_qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr),
platform='cpu')
if cusolver is not None:
xla.backend_specific_translations['gpu'][qr_p] = partial(
_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr)
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][qr_p] = partial(
_qr_cpu_gpu_translation_rule, rocsolver.geqrf, rocsolver.orgqr)
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, rocsolver.geqrf, rocsolver.orgqr),
platform='gpu')
# Singular value decomposition
@ -1198,12 +1226,15 @@ def svd_impl(operand, full_matrices, compute_uv):
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
def svd_translation_rule(c, operand, full_matrices, compute_uv):
shape = c.get_shape(operand).dimensions()
def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
compute_uv):
operand_aval, = avals_in
shape = operand_aval.shape
m, n = shape[-2:]
if m == 0 or n == 0:
return xla.lower_fun(_empty_svd, multiple_results=True)(
c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
return xla.lower_fun(_empty_svd, multiple_results=True, new_style=True)(
ctx, avals_in, avals_out, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
u, s, v = xops.SVD(operand)
permutation = list(range(len(shape)))
@ -1214,9 +1245,9 @@ def svd_translation_rule(c, operand, full_matrices, compute_uv):
vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
if not compute_uv:
return xops.Tuple(c, [s])
return [s]
else:
return xops.Tuple(c, [s, u, vt])
return [s, u, vt]
def svd_abstract_eval(operand, full_matrices, compute_uv):
@ -1293,19 +1324,22 @@ def _empty_svd(a, *, full_matrices, compute_uv):
u, v = v, u
return s, u, v
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
shape = c.get_shape(operand).dimensions()
m, n = shape[-2:]
batch_dims = shape[:-2]
def _svd_cpu_gpu_translation_rule(gesvd_impl, ctx, avals_in, avals_out, operand,
*, full_matrices, compute_uv):
operand_aval, = avals_in
m, n = operand_aval.shape[-2:]
batch_dims = operand_aval.shape[:-2]
c = ctx.builder
if m == 0 or n == 0:
return xla.lower_fun(_empty_svd, multiple_results=True)(
c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
return xla.lower_fun(_empty_svd, multiple_results=True, new_style=True)(
ctx, avals_in, avals_out, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
s, u, vt, info = gesvd_impl(c, operand,
full_matrices=full_matrices,
compute_uv=compute_uv)
ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
_nan_like(c, s))
@ -1318,7 +1352,7 @@ def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute
_nan_like(c, vt))
result += [u, vt]
return xops.Tuple(c, result)
return result
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
x, = batched_args
@ -1337,20 +1371,27 @@ svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule
xla.register_translation(svd_p, _svd_translation_rule)
xla.backend_specific_translations['cpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, lapack.gesdd)
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, lapack.gesdd),
platform='cpu')
if cusolver is not None:
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, cusolver.gesvd)
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, cusolver.gesvd),
platform='gpu')
if rocsolver is not None:
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, rocsolver.gesvd)
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, rocsolver.gesvd),
platform='gpu')
def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
b, *, m, n, ldb, t):
return [cusparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
tridiagonal_solve_p.def_impl(
@ -1358,18 +1399,11 @@ tridiagonal_solve_p.def_impl(
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
if cusparse is not None and hasattr(cusparse, "gtsv2"):
xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2
xla.register_translation(tridiagonal_solve_p,
_tridiagonal_solve_gpu_translation_rule,
platform='gpu')
def _tridiagonal_solve_translation_rule(c, dl, d, du, b, *, m, n, ldb, t):
del m, n, ldb, t
lowered_fun = xla.lower_fun(_tridiagonal_solve_jax, multiple_results=False)
return lowered_fun(c, dl, d, du, b)
xla.translations[tridiagonal_solve_p] = _tridiagonal_solve_translation_rule
def _tridiagonal_solve_jax(dl, d, du, b):
def _tridiagonal_solve_jax(dl, d, du, b, **kw):
"""Pure JAX implementation of `tridiagonal_solve`."""
prepend_zero = lambda x: jnp.append(jnp.zeros([1], dtype=x.dtype), x[:-1])
fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
@ -1397,6 +1431,10 @@ def _tridiagonal_solve_jax(dl, d, du, b):
return x_[::-1]
xla.register_translation(tridiagonal_solve_p, xla.lower_fun(
_tridiagonal_solve_jax, multiple_results=False, new_style=True))
def tridiagonal_solve(dl, d, du, b):
r"""Computes the solution of a tridiagonal linear system.
@ -1469,8 +1507,8 @@ def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
select_callable=select_callable)
def _schur_translation_rule(c, operand, *, compute_schur_vectors,
sort_eig_vals):
def _schur_translation_rule(ctx, avals_in, avals_out, operand, *,
compute_schur_vectors, sort_eig_vals):
raise NotImplementedError(
"Schur decomposition is only implemented on the CPU backend.")
@ -1492,10 +1530,12 @@ def _schur_abstract_eval(operand, *, compute_schur_vectors, sort_eig_vals,
return (T, vs) if compute_schur_vectors else (T,)
def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
sort_eig_vals, select_callable):
shape = c.get_shape(operand)
batch_dims = shape.dimensions()[:-2]
def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
compute_schur_vectors, sort_eig_vals,
select_callable):
operand_aval, = avals_in
batch_dims = operand_aval.shape[:-2]
c = ctx.builder
if jaxlib_version < (0, 1, 72):
raise NotImplementedError(
@ -1519,7 +1559,7 @@ def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable)
ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
T = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), T,
_nan_like(c, T))
output = [T]
@ -1529,7 +1569,7 @@ def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
output.append(vs)
return xops.Tuple(c, output)
return output
def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
@ -1555,7 +1595,7 @@ schur_p = Primitive('schur')
schur_p.multiple_results = True
schur_p.def_impl(_schur_impl)
schur_p.def_abstract_eval(_schur_abstract_eval)
xla.translations[schur_p] = _schur_translation_rule
xla.backend_specific_translations['cpu'][schur_p] = _schur_cpu_translation_rule
xla.register_translation(schur_p, _schur_translation_rule)
xla.register_translation(schur_p, _schur_cpu_translation_rule, platform='cpu')
batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule
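
A change that recurs through the linalg hunks above: rules no longer query the builder for shape information (c.get_shape(operand).dimensions(), shape.element_type()); the abstract values passed to the rule already carry it. A hedged one-input sketch, assuming the module's existing np/xops imports:

```
def _rule(ctx, avals_in, avals_out, operand):
  operand_aval, = avals_in              # was: shape = c.get_shape(operand)
  batch_dims = operand_aval.shape[:-2]  # was: shape.dimensions()[:-2]
  dtype = operand_aval.dtype            # was: shape.element_type().type
  c = ctx.builder                       # builder still available when needed
  return [xops.Constant(c, np.zeros(batch_dims, dtype))]
```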

View File

@ -350,21 +350,24 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
return tuple(x)
def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
shape = lax.broadcast_shapes(
c.get_shape(k1).dimensions(), c.get_shape(k2).dimensions(),
c.get_shape(x1).dimensions(), c.get_shape(x2).dimensions())
rank = len(shape)
if 0 in shape:
def _threefry2x32_gpu_translation_rule(ctx, avals_in, avals_out, k1, k2, x1,
x2):
aval_out, _ = avals_out
k1_aval, k2_aval, x1_aval, x2_aval = avals_in
rank = len(aval_out.shape)
if 0 in aval_out.shape:
zeros = xla_client.ops.Broadcast(
xla_client.ops.Constant(c, np.array(0, np.uint32)), shape)
return xla_client.ops.Tuple(c, [zeros, zeros])
def _broadcast(x):
ndims = c.get_shape(x).rank()
return xla_client.ops.BroadcastInDim(x, shape,
tuple(range(rank - ndims, rank)))
return cuda_prng.threefry2x32(
c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
xla_client.ops.Constant(ctx.builder, np.array(0, np.uint32)),
aval_out.shape)
return [zeros, zeros]
def _broadcast(x, aval):
return xla_client.ops.BroadcastInDim(
x, aval_out.shape, tuple(range(rank - len(aval.shape), rank)))
return xla.xla_destructure(
ctx.builder,
cuda_prng.threefry2x32(
ctx.builder, (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
threefry2x32_p = core.Primitive("threefry2x32")
@ -379,8 +382,8 @@ xla.register_translation(threefry2x32_p, xla.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True, new_style=True), platform='cpu')
if cuda_prng:
xla.backend_specific_translations['gpu'][threefry2x32_p] = \
_threefry2x32_gpu_translation_rule
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
platform='gpu')
@partial(jit, inline=True)
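
The GPU threefry rule above also illustrates the multiple-results convention: a new-style rule returns a flat list of XlaOps, one per output aval, and xla.xla_destructure splits a tuple-typed op, such as the one cuda_prng.threefry2x32 builds, into that list. A simplified sketch that omits the broadcasting the real rule performs:

```
def _threefry_gpu_sketch(ctx, avals_in, avals_out, k1, k2, x1, x2):
  # cuda_prng.threefry2x32 emits one tuple-typed op; destructure it into
  # one XlaOp per element so the rule can return a flat list.
  tup = cuda_prng.threefry2x32(ctx.builder, (k1, k2), (x1, x2))
  return xla.xla_destructure(ctx.builder, tup)
```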

View File

@ -798,7 +798,8 @@ def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
id_tap_dep_p = core.Primitive("id_tap_dep")
id_tap_dep_p.multiple_results = False
id_tap_dep_p.def_impl(lambda r, _: r)
xla.translations[id_tap_dep_p] = lambda comp, a_res, a_tap: a_res
xla.register_translation(id_tap_dep_p,
lambda ctx, avals_in, avals_out, a_res, a_tap: [a_res])
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
def _id_tap_dep_jvp_rule(primals, tangents):
@ -923,9 +924,8 @@ def _outside_call_impl(*args, **params):
outside_call_p.def_impl(_outside_call_impl)
def _outside_call_translation_rule(comp: XlaBuilder,
def _outside_call_translation_rule(ctx, avals_in, avals_out,
*args_op: XlaOp,
platform="tpu",
has_token,
identity,
flat_results_aval=(),
@ -934,6 +934,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
assert has_token
current_token = args_op[-2]
current_itoken = args_op[-1]
comp = ctx.builder
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
"The last two arguments must be tokens")
@ -944,7 +945,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
flat_results_aval))
need_callback_results_on_device = (not identity and
len(non_empty_flat_results_aval) > 0)
use_outfeed = _use_outfeed(platform)
use_outfeed = _use_outfeed(ctx.platform)
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op
if use_outfeed:
@ -1041,7 +1042,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
xla.aval_to_xla_shapes(res_aval)[0]
for res_aval in callback_flat_results_aval
]
backend = xb.get_backend(platform)
backend = xb.get_backend(ctx.platform)
token_and_results_op, keep_alive = backend.emit_python_callback(
wrapped_callback,
comp,
@ -1062,12 +1063,10 @@ def _outside_call_translation_rule(comp: XlaBuilder,
assert identity or len(results) == len(flat_results_aval), (
f"got {len(results)} but expected {len(flat_results_aval)}. "
f"identity = {identity}")
return xops.Tuple(comp, results + [next_token, next_itoken])
return results + [next_token, next_itoken]
for platform in ["cpu", "gpu", "tpu"]:
xla.backend_specific_translations[platform][outside_call_p] = (
functools.partial(_outside_call_translation_rule, platform=platform))
xla.register_translation(outside_call_p, _outside_call_translation_rule)
def _outside_call_run_callback(
@ -1383,7 +1382,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
else:
output_token_var = mk_new_var(last_token_var.aval)
output_itoken_var = mk_new_var(last_itoken_var.aval)
_rewrite_eqn(platform, eqn, eqns, last_token_var, output_token_var,
_rewrite_eqn(eqn, eqns, last_token_var, output_token_var,
last_itoken_var, output_itoken_var, mk_new_var)
last_token_var = output_token_var
last_itoken_var = output_itoken_var
@ -1393,7 +1392,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
return new_jaxpr
def _rewrite_eqn(platform: str, eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
input_token_var: core.Var, output_token_var: core.Var,
input_itoken_var: core.Var, output_itoken_var: core.Var,
mk_new_var: Callable[[core.AbstractValue], core.Var]):
@ -1691,7 +1690,7 @@ id_p = core.Primitive("id")
id_p.multiple_results = True
id_p.def_impl(lambda *args: args)
id_p.def_abstract_eval(lambda *args: args)
xla.translations[id_p] = lambda c, *args: xops.Tuple(c, args)
xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
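
The host_callback hunks drop the per-platform functools.partial loop because the lowering platform now travels on the context. A hedged sketch (hypothetical primitive and body; xb.get_backend(ctx.platform) is the call the diff uses):

```
def _platform_aware_rule(ctx, avals_in, avals_out, x, **params):
  backend = xb.get_backend(ctx.platform)  # was: xb.get_backend(platform)
  del backend, params  # a real rule would emit platform-specific ops here
  return [x]

# A single registration now covers cpu, gpu, and tpu.
xla.register_translation(my_p, _platform_aware_rule)
```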

View File

@ -237,7 +237,7 @@ def _call_tf_abstract_eval(*_,
call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
def _call_tf_translation_rule(ctx, avals_in, avals_out, *args_op,
function_flat_tf,
args_flat_sig_tf,
**_):
@ -245,7 +245,7 @@ def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
code_gen, _ = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, # type: ignore
code_gen_optional=False)
assert code_gen is not None
return code_gen(builder, args_op)
return code_gen(ctx.builder, args_op)
@functools.lru_cache(maxsize=128)
@ -253,7 +253,8 @@ def _code_generator_and_avals(
function_flat_tf,
args_flat_sig_tf,
code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]],
Sequence[xla.XlaOp]]],
Sequence[core.ShapedArray]]:
# Returns and caches a code generator (taking a builder and the
# XlaOps for the arguments) and a sequence of result abstract shapes.
@ -384,8 +385,9 @@ def _code_generator_and_avals(
result_avals = tuple(map(canonical_res_aval, result_shapes)) # type: ignore
def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]
) -> Sequence[xla.XlaOp]:
captured_ops = [xops.Constant(builder, np.asarray(inp))
for inp in captured_inputs]
res_tf = xops.Call(builder, xla_comp, args_op + tuple(captured_ops)) # type: ignore
@ -399,15 +401,14 @@ def _code_generator_and_avals(
new_element_type=xla.dtype_to_primitive_type(res_aval.dtype))
return res_op
results = [
return [
post_process_result(i, res_aval, res_shape)
for i, (res_aval, res_shape) in enumerate(zip(result_avals,
result_shapes))]
return xops.Tuple(builder, results)
return code_gen, result_avals
xla.translations[call_tf_p] = _call_tf_translation_rule
xla.register_translation(call_tf_p, _call_tf_translation_rule)
TfVal = jax2tf_internal.TfVal
def _jax2tf_call_tf(*args: TfVal,

View File

@ -752,11 +752,13 @@ ad.deflinear2(sharding_constraint_p,
sharding_constraint_p.bind(
ct, axis_resources=axis_resources, resource_env=resource_env),))
def _sharding_constraint_translation_rule(c, x_node, axis_resources, resource_env):
def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
axis_resources, resource_env):
mesh = resource_env.physical_mesh
return xb.set_sharding_proto(c, x_node,
get_sharding_proto(c, x_node, axis_resources, mesh))
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
return [xb.set_sharding_proto(
ctx.builder, x_node,
get_sharding_proto(ctx.builder, x_node, axis_resources, mesh))]
xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)
def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type, vals_in, dims_in,
axis_resources, resource_env):

View File

@ -188,8 +188,8 @@ def _bcoo_todense_batching_rule(batched_args, batch_dims, *, shape):
ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
xla.translations[bcoo_todense_p] = xla.lower_fun(
_bcoo_todense_impl, multiple_results=False)
xla.register_translation(bcoo_todense_p, xla.lower_fun(
_bcoo_todense_impl, multiple_results=False, new_style=True))
#--------------------------------------------------------------------
# bcoo_fromdense
@ -288,8 +288,8 @@ def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_d
ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
xla.translations[bcoo_fromdense_p] = xla.lower_fun(
_bcoo_fromdense_impl, multiple_results=True)
xla.register_translation(bcoo_fromdense_p, xla.lower_fun(
_bcoo_fromdense_impl, multiple_results=True, new_style=True))
#----------------------------------------------------------------------
# bcoo_extract
@ -355,8 +355,8 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims):
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
xla.translations[bcoo_extract_p] = xla.lower_fun(
_bcoo_extract_impl, multiple_results=False)
xla.register_translation(bcoo_extract_p, xla.lower_fun(
_bcoo_extract_impl, multiple_results=False, new_style=True))
#----------------------------------------------------------------------
# bcoo_transpose
@ -450,8 +450,8 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.translations[bcoo_transpose_p] = xla.lower_fun(
_bcoo_transpose_impl, multiple_results=True)
xla.register_translation(bcoo_transpose_p, xla.lower_fun(
_bcoo_transpose_impl, multiple_results=True, new_style=True))
#----------------------------------------------------------------------
# bcoo_dot_general
@ -620,8 +620,8 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
ad.defjvp(bcoo_dot_general_p, _bcoo_dot_general_jvp_lhs, None, _bcoo_dot_general_jvp_rhs)
ad.primitive_transposes[bcoo_dot_general_p] = _bcoo_dot_general_transpose
batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
xla.translations[bcoo_dot_general_p] = xla.lower_fun(
_bcoo_dot_general_impl, multiple_results=False)
xla.register_translation(bcoo_dot_general_p, xla.lower_fun(
_bcoo_dot_general_impl, multiple_results=False, new_style=True))
#----------------------------------------------------------------------
# bcoo_dot_general_sampled
@ -672,8 +672,8 @@ ad.defjvp(bcoo_dot_general_sampled_p, _bcoo_dot_general_sampled_jvp_A,
_bcoo_dot_general_sampled_jvp_B, None)
ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_transpose
batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
xla.translations[bcoo_dot_general_sampled_p] = xla.lower_fun(
_bcoo_dot_general_sampled_impl, multiple_results=False)
xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
_bcoo_dot_general_sampled_impl, multiple_results=False, new_style=True))
#----------------------------------------------------------------------
# bcoo_spdot_general
@ -816,8 +816,8 @@ def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, dimension_number
# TODO(JVP): jvp, transpose
batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule
xla.translations[bcoo_spdot_general_p] = xla.lower_fun(
_bcoo_spdot_general_impl, multiple_results=True)
xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
_bcoo_spdot_general_impl, multiple_results=True, new_style=True))
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?

View File

@ -103,14 +103,15 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
assert indptr.shape[0] == shape[0] + 1
return core.ShapedArray(shape, data.dtype)
def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
return cusparse.csr_todense(c, data, indices, indptr, shape=shape)
def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, *, shape):
return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
xla.translations[csr_todense_p] = xla.lower_fun(
_csr_todense_impl, multiple_results=False)
xla.register_translation(csr_todense_p, xla.lower_fun(
_csr_todense_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_todense_p] = _csr_todense_gpu_translation_rule
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
# csr_fromdense
@ -159,16 +160,18 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
return data, indices, indptr
def _csr_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
data, indices, indptr = cusparse.csr_fromdense(
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, indices, indptr])
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, indices, indptr]
xla.translations[csr_fromdense_p] = xla.lower_fun(
_csr_fromdense_impl, multiple_results=True)
xla.register_translation(csr_fromdense_p, xla.lower_fun(
_csr_fromdense_impl, multiple_results=True, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_fromdense_p] = _csr_fromdense_gpu_translation_rule
xla.register_translation(csr_fromdense_p,
_csr_fromdense_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
# csr_matvec
@ -211,14 +214,16 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
assert v.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape,), data.dtype)
def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
return cusparse.csr_matvec(c, data, indices, indptr, v, shape=shape, transpose=transpose)
def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, v, *, shape, transpose):
return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
shape=shape, transpose=transpose)]
xla.translations[csr_matvec_p] = xla.lower_fun(
_csr_matvec_impl, multiple_results=False)
xla.register_translation(csr_matvec_p, xla.lower_fun(
_csr_matvec_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_matvec_p] = _csr_matvec_gpu_translation_rule
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
@ -263,14 +268,16 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
assert B.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
return cusparse.csr_matmat(c, data, indices, indptr, B, shape=shape, transpose=transpose)
def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, B, *, shape, transpose):
return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
shape=shape, transpose=transpose)]
xla.translations[csr_matmat_p] = xla.lower_fun(
_csr_matmat_impl, multiple_results=False)
xla.register_translation(csr_matmat_p, xla.lower_fun(
_csr_matmat_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
csr_matmat_p] = _csr_matmat_gpu_translation_rule
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
@ -300,8 +307,9 @@ def _coo_todense_impl(data, row, col, *, shape):
def _coo_todense_abstract_eval(data, row, col, *, shape):
return core.ShapedArray(shape, data.dtype)
def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
return cusparse.coo_todense(c, data, row, col, shape=shape)
def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
*, shape):
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
return coo_todense(data_dot, row, col, shape=shape)
@ -319,11 +327,11 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.translations[coo_todense_p] = xla.lower_fun(
_coo_todense_impl, multiple_results=False)
xla.register_translation(coo_todense_p, xla.lower_fun(
_coo_todense_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_todense_p] = _coo_todense_gpu_translation_rule
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
# coo_fromdense
@ -367,10 +375,11 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
row = col = core.ShapedArray((nse,), index_dtype)
return data, row, col
def _coo_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
data, row, col = cusparse.coo_fromdense(
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, row, col])
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, row, col]
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
M, = primals
@ -400,11 +409,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
xla.translations[coo_fromdense_p] = xla.lower_fun(
_coo_fromdense_impl, multiple_results=True)
xla.register_translation(coo_fromdense_p, xla.lower_fun(
_coo_fromdense_impl, multiple_results=True, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
xla.register_translation(coo_fromdense_p,
_coo_fromdense_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
# coo_matvec
@ -450,8 +460,10 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape,), data.dtype)
def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
v, *, shape, transpose):
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
@ -472,11 +484,11 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.translations[coo_matvec_p] = xla.lower_fun(
_coo_matvec_impl, multiple_results=False)
xla.register_translation(coo_matvec_p, xla.lower_fun(
_coo_matvec_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_matvec_p] = _coo_matvec_gpu_translation_rule
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
platform='gpu')
#--------------------------------------------------------------------
# coo_matmat
@ -521,14 +533,16 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def _coo_matmat_gpu_translation_rule(c, data, row, col, B, *, shape, transpose):
return cusparse.coo_matmat(c, data, row, col, B, shape=shape, transpose=transpose)
def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
B, *, shape, transpose):
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
xla.translations[coo_matmat_p] = xla.lower_fun(
_coo_matmat_impl, multiple_results=False)
xla.register_translation(coo_matmat_p, xla.lower_fun(
_coo_matmat_impl, multiple_results=False, new_style=True))
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_matmat_p] = _coo_matmat_gpu_translation_rule
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
platform='gpu')
def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
vals, rows, cols, mat = primals_in

View File

@ -409,15 +409,17 @@ def _sharding_constraint_impl(x, partitions):
raise NotImplementedError(
"with_sharding_constraint() should only be called inside sharded_jit()")
def _sharding_constraint_translation_rule(c, x_node, partitions):
return xb.set_sharding(c, x_node, partitions)
def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
partitions):
return [xb.set_sharding(ctx.builder, x_node, partitions)]
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
ad.deflinear2(sharding_constraint_p,
lambda ct, _, partitions: (with_sharding_constraint(ct, partitions),))
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
xla.register_translation(sharding_constraint_p,
_sharding_constraint_translation_rule)
def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
"""Identity-like function that specifies how ``x`` should be sharded.

View File

@ -2699,7 +2699,8 @@ class APITest(jtu.JaxTestCase):
tokentest_p = core.Primitive("tokentest")
tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
tokentest_p.def_abstract_eval(lambda x, y: x)
xla.translations[tokentest_p] = lambda c, x, y: x
xla.register_translation(tokentest_p,
lambda ctx, avals_in, avals_out, x, y: [x])
ad.defjvp(tokentest_p, (lambda g, x, token: x), None)
token = jax.lax.create_token(123)

View File

@ -145,12 +145,12 @@ def _sp_indices_impl(mat):
def _sp_indices_abstract_eval(mat):
return mat.indices_aval
def _sp_indices_translation_rule(c, data, indices):
return indices
def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
return [indices]
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.translations[sp_indices_p] = _sp_indices_translation_rule
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
sp_data_p = core.Primitive('sp_data')
@ -162,12 +162,12 @@ def _sp_data_impl(mat):
def _sp_data_abstract_eval(mat):
return mat.data_aval
def _sp_data_translation_rule(c, data, indices):
return data
def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
return [data]
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.translations[sp_data_p] = _sp_data_translation_rule
xla.register_translation(sp_data_p, _sp_data_translation_rule)
def identity(x):
return identity_p.bind(x)