Mirror of https://github.com/ROCm/jax.git
Port remaining translation rules inside JAX to new style.

PiperOrigin-RevId: 404288551

parent f66985de25
commit e783cbcb72
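
Every hunk below applies the same mechanical change. Old-style translation rules took an XlaBuilder and the XlaOps of the arguments, returned a single XlaOp (bundling multiple results with xops.Tuple), and were registered by assignment into xla.translations or xla.backend_specific_translations[platform]. New-style rules take a context object carrying the builder and platform, plus the input and output abstract values, return a list of XlaOps (one per output), and are registered with xla.register_translation. A minimal sketch of the two styles against the 2021-era internals this commit touches; the primitive my_prim and its rules are hypothetical illustrations, not part of the diff, and this API has since been superseded by MLIR lowerings in modern JAX:

    from jax import core
    from jax._src.lib import xla_client
    from jax.interpreters import xla

    my_prim = core.Primitive("my_prim")  # hypothetical primitive

    # Old style: rule(builder, *xla_op_args, **params) -> a single XlaOp,
    # registered by dictionary assignment.
    def _my_prim_translation_old(c, x):
      return xla_client.ops.Neg(x)
    # xla.translations[my_prim] = _my_prim_translation_old

    # New style: rule(ctx, avals_in, avals_out, *xla_op_args, **params)
    # returns a list of XlaOps, one per primitive output; the builder is
    # ctx.builder and the lowering platform is ctx.platform.
    def _my_prim_translation_new(ctx, avals_in, avals_out, x):
      return [xla_client.ops.Neg(x)]

    xla.register_translation(my_prim, _my_prim_translation_new)
    # Backend-specific rules pass platform='cpu', 'gpu', or 'tpu'.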
@@ -616,7 +616,7 @@
    "source": [
     "from jax._src.lib import xla_client\n",
     "@trace(\"multiply_add_xla_translation\")\n",
-    "def multiply_add_xla_translation(c, xc, yc, zc):\n",
+    "def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):\n",
     "  \"\"\"The compilation to XLA of the primitive.\n",
     "\n",
     "  Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the\n",
@@ -624,12 +624,12 @@
     "\n",
     "  Does not need to be a JAX-traceable function.\n",
     "  \"\"\"\n",
-    "  return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)\n",
+    "  return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]\n",
     "\n",
     "# Now we register the XLA compilation rule with JAX\n",
     "# TODO: for GPU? and TPU?\n",
     "from jax.interpreters import xla\n",
-    "xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
+    "xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')"
    ]
   },
   {
@@ -356,7 +356,7 @@ for most of them. However, XLA includes a `CustomCall` operation that can be use
 
 from jax._src.lib import xla_client
 @trace("multiply_add_xla_translation")
-def multiply_add_xla_translation(c, xc, yc, zc):
+def multiply_add_xla_translation(ctx, avals_in, avals_out, xc, yc, zc):
   """The compilation to XLA of the primitive.
 
   Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
@@ -364,12 +364,12 @@ def multiply_add_xla_translation(c, xc, yc, zc):
 
   Does not need to be a JAX-traceable function.
   """
-  return xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)
+  return [xla_client.ops.Add(xla_client.ops.Mul(xc, yc), zc)]
 
 # Now we register the XLA compilation rule with JAX
 # TODO: for GPU? and TPU?
 from jax.interpreters import xla
-xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation
+xla.register_translation(multiply_add_p, multiply_add_xla_translation, platform='cpu')
 ```
 
 +++ {"id": "K98LX-VaJkFu"}
@@ -405,7 +405,8 @@ def name_jvp(primals, tangents, *, name):
   return name_p.bind(x, name=name), xdot  # don't name the tangent value
 ad.primitive_jvps[name_p] = name_jvp
 
-xla.translations[name_p] = lambda c, x, *, name: x
+xla.register_translation(name_p,
+                         lambda ctx, avals_in, avals_out, x, *, name: [x])
 
 def name_batcher(args, dims, *, name):
   (x,), (d,) = args, dims
@@ -696,7 +696,7 @@ xla.register_translation(
     initial_style=True)
 
 batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
-xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
+xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
 
 pe.partial_eval_jaxpr_custom_rules[custom_vjp_call_jaxpr_p] = \
     custom_jvp_jaxpr_custom_partial_eval_rule  # type: ignore
@@ -2718,10 +2718,11 @@ def _cumulative_reduction_primitive(name,
                        translation_rule=xla.lower_fun(
                            partial(associative_scan, reduce_fn),
                            multiple_results=False, new_style=True))
-  xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
-      partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
-      multiple_results=False)
-  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
+  xla.register_translation(reducer_p, xla.lower_fun(
+      partial(_cumred_tpu_translation_rule, tpu_reduce_window_fn),
+      multiple_results=False, new_style=True), platform='tpu')
+  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
+                                                   reducer_p)
   return reducer_p
 
 cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum)
@@ -77,8 +77,14 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
     dtype = x.dtype
   return x.update(shape=shape, dtype=dtype)
 
-def fft_translation_rule(c, x, fft_type, fft_lengths):
-  return xops.Fft(x, fft_type, fft_lengths)
+def _fft_translation_rule(ctx, avals_in, avals_out, x, *, fft_type,
+                          fft_lengths):
+  return [xops.Fft(x, fft_type, fft_lengths)]
+
+def _fft_translation_rule_cpu(ctx, avals_in, avals_out, x, *, fft_type,
+                              fft_lengths):
+  return [pocketfft.pocketfft(ctx.builder, x, fft_type=fft_type,
+                              fft_lengths=fft_lengths)]
 
 def _naive_rfft(x, fft_lengths):
   y = fft(x, xla_client.FftType.FFT, fft_lengths)
@@ -138,8 +144,8 @@ def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
 fft_p = Primitive('fft')
 fft_p.def_impl(fft_impl)
 fft_p.def_abstract_eval(fft_abstract_eval)
-xla.translations[fft_p] = fft_translation_rule
+xla.register_translation(fft_p, _fft_translation_rule)
 ad.deflinear2(fft_p, fft_transpose_rule)
 batching.primitive_batchers[fft_p] = fft_batching_rule
 if pocketfft:
-  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
+  xla.register_translation(fft_p, _fft_translation_rule_cpu, platform='cpu')
@@ -350,25 +350,33 @@ def _nan_like(c, operand):
   nan = xops.Constant(c, np.array(np.nan, dtype=dtype))
   return xops.Broadcast(nan, shape.dimensions())
 
-def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
-  shape = c.get_shape(operand)
-  batch_dims = shape.dimensions()[:-2]
+def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out,
+                                       operand):
+  operand_aval, = avals_in
+  c = ctx.builder
+  batch_dims = operand_aval.shape[:-2]
   result, info = potrf_impl(c, operand, lower=True)
-  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
-  return _broadcasting_select(c,
-                              xops.Reshape(ok, batch_dims + (1, 1)), result,
-                              _nan_like(c, result))
+  ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
+  return [_broadcasting_select(c,
+                               xops.Reshape(ok, batch_dims + (1, 1)), result,
+                               _nan_like(c, result))]
 
-xla.backend_specific_translations['cpu'][cholesky_p] = partial(
-    _cholesky_cpu_gpu_translation_rule, lapack.potrf)
+xla.register_translation(
+    cholesky_p,
+    partial(_cholesky_cpu_gpu_translation_rule, lapack.potrf),
+    platform='cpu')
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
-      _cholesky_cpu_gpu_translation_rule, cusolver.potrf)
+  xla.register_translation(
+      cholesky_p,
+      partial(_cholesky_cpu_gpu_translation_rule, cusolver.potrf),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
-      _cholesky_cpu_gpu_translation_rule, rocsolver.potrf)
+  xla.register_translation(
+      cholesky_p,
+      partial(_cholesky_cpu_gpu_translation_rule, rocsolver.potrf),
+      platform='gpu')
 
 # Asymmetric eigendecomposition
 
@@ -409,15 +417,17 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
 
 _cpu_geev = lapack.geev
 
-def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors,
-                             compute_right_eigenvectors):
-  shape = c.get_shape(operand)
-  batch_dims = shape.dimensions()[:-2]
+def _eig_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
+                              compute_left_eigenvectors,
+                              compute_right_eigenvectors):
+  operand_aval, = avals_in
+  batch_dims = operand_aval.shape[:-2]
+  c = ctx.builder
 
   w, vl, vr, info = _cpu_geev(c, operand, jobvl=compute_left_eigenvectors,
                               jobvr=compute_right_eigenvectors)
 
-  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
+  ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
   w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                            _nan_like(c, w))
   output = [w]
@@ -432,7 +442,7 @@ def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors,
                              _nan_like(c, vr))
     output.append(vr)
 
-  return xops.Tuple(c, output)
+  return output
 
 def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
                       compute_right_eigenvectors):
@@ -462,8 +472,8 @@ eig_p = Primitive('eig')
 eig_p.multiple_results = True
 eig_p.def_impl(eig_impl)
 eig_p.def_abstract_eval(eig_abstract_eval)
-xla.translations[eig_p] = eig_translation_rule
-xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
+xla.register_translation(eig_p, eig_translation_rule)
+xla.register_translation(eig_p, _eig_cpu_translation_rule, platform='cpu')
 batching.primitive_batchers[eig_p] = eig_batching_rule
 ad.primitive_jvps[eig_p] = eig_jvp_rule
 
@@ -474,12 +484,11 @@ def eigh_impl(operand, lower):
   v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
   return v, w
 
-def eigh_translation_rule(c, operand, lower):
-  shape = c.get_shape(operand)
-  dims = shape.dimensions()
-  if dims[-1] == 0:
-    return xops.Tuple(c, [operand, xops.Real(xops.Reshape(operand, dims[:-1]))])
-  return xops.Tuple(c, xops.Eigh(operand, lower=lower))
+def _eigh_translation_rule(ctx, avals_in, avals_out, operand, *, lower):
+  operand_aval, = avals_in
+  if operand_aval.shape[-1] == 0:
+    return [operand, xops.Real(xops.Reshape(operand, operand_aval.shape[:-1]))]
+  return xops.Eigh(operand, lower=lower)
 
 def eigh_abstract_eval(operand, lower):
   if isinstance(operand, ShapedArray):
@@ -497,16 +506,18 @@ def eigh_abstract_eval(operand, lower):
     v, w = operand, operand
   return v, w
 
-def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
-  shape = c.get_shape(operand)
-  batch_dims = shape.dimensions()[:-2]
+def _eigh_cpu_gpu_translation_rule(syevd_impl, ctx, avals_in, avals_out,
+                                   operand, *, lower):
+  operand_aval, = avals_in
+  batch_dims = operand_aval.shape[:-2]
+  c = ctx.builder
   v, w, info = syevd_impl(c, operand, lower=lower)
-  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
+  ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
   v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
                            _nan_like(c, v))
   w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                            _nan_like(c, w))
-  return xops.Tuple(c, [v, w])
+  return [v, w]
 
 def eigh_jvp_rule(primals, tangents, lower):
   # Derivative for eigh in the simplest case of distinct eigenvalues.
@@ -545,22 +556,25 @@ eigh_p = Primitive('eigh')
 eigh_p.multiple_results = True
 eigh_p.def_impl(eigh_impl)
 eigh_p.def_abstract_eval(eigh_abstract_eval)
-xla.translations[eigh_p] = eigh_translation_rule
+xla.register_translation(eigh_p, _eigh_translation_rule)
 ad.primitive_jvps[eigh_p] = eigh_jvp_rule
 batching.primitive_batchers[eigh_p] = eigh_batching_rule
 
 _cpu_syevd = lapack.syevd
 
-xla.backend_specific_translations['cpu'][eigh_p] = partial(
-    _eigh_cpu_gpu_translation_rule, _cpu_syevd)
+xla.register_translation(
+    eigh_p, partial(_eigh_cpu_gpu_translation_rule, _cpu_syevd),
+    platform='cpu')
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][eigh_p] = partial(
-      _eigh_cpu_gpu_translation_rule, cusolver.syevd)
+  xla.register_translation(
+      eigh_p, partial(_eigh_cpu_gpu_translation_rule, cusolver.syevd),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][eigh_p] = partial(
-      _eigh_cpu_gpu_translation_rule, rocsolver.syevd)
+  xla.register_translation(
+      eigh_p, partial(_eigh_cpu_gpu_translation_rule, rocsolver.syevd),
+      platform='gpu')
 
 
 triangular_solve_dtype_rule = partial(
@@ -686,17 +700,18 @@ batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
 
 
 def _triangular_solve_cpu_translation_rule(
-    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
-  shape = c.get_shape(a)
-  dtype = shape.element_type().type
+    ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
+    conjugate_a, unit_diagonal):
+  a_aval, _ = avals_in
+  c = ctx.builder
 
   if conjugate_a and not transpose_a:
     a = xops.Conj(a)
    conjugate_a = False
-  if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types:
-    return lapack.jax_trsm(
-      c, xops.Constant(c, np.array(1, dtype=dtype)),
-      a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
+  if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
+    return [lapack.jax_trsm(
+      c, xops.Constant(c, np.array(1, dtype=a_aval.dtype)),
+      a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
  else:
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
@@ -705,24 +720,26 @@ def _triangular_solve_cpu_translation_rule(
     else:
       transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                    else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
-    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)
+    return [xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
+                                 transpose)]
 
-xla.backend_specific_translations['cpu'][triangular_solve_p] = \
-  _triangular_solve_cpu_translation_rule
+xla.register_translation(triangular_solve_p,
+                         _triangular_solve_cpu_translation_rule,
+                         platform='cpu')
 
-def _triangular_solve_gpu_translation_rule(trsm_impl,
-    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
-  shape = c.get_shape(a)
-  dims = shape.dimensions()
-  m, n = dims[-2:]
-  batch = prod(dims[:-2])
+def _triangular_solve_gpu_translation_rule(
+    trsm_impl, ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
+    conjugate_a, unit_diagonal):
+  c = ctx.builder
+  a_aval, _ = avals_in
+  m, n = a_aval.shape[-2:]
+  batch = prod(a_aval.shape[:-2])
   if conjugate_a and not transpose_a:
     a = xops.Conj(a)
     conjugate_a = False
   if batch > 1 and m <= 256 and n <= 256:
-    return trsm_impl(
-      c, a, b, left_side, lower, transpose_a,
-      conjugate_a, unit_diagonal)
+    return [trsm_impl(c, a, b, left_side, lower, transpose_a,
+                      conjugate_a, unit_diagonal)]
   else:
     # Use the XLA implementation for unbatched triangular_solve.
     if not transpose_a:
@@ -730,16 +747,20 @@ def _triangular_solve_gpu_translation_rule(trsm_impl,
     else:
       transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                    else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
-    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
-                                transpose)
+    return [xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
+                                 transpose)]
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][triangular_solve_p] = \
-      partial(_triangular_solve_gpu_translation_rule, cusolver.trsm)
+  xla.register_translation(
+      triangular_solve_p,
+      partial(_triangular_solve_gpu_translation_rule, cusolver.trsm),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][triangular_solve_p] = \
-      partial(_triangular_solve_gpu_translation_rule, rocsolver.trsm)
+  xla.register_translation(
+      triangular_solve_p,
+      partial(_triangular_solve_gpu_translation_rule, rocsolver.trsm),
+      platform='gpu')
 
 # Support operation for LU decomposition: Transformation of the pivots returned
 # by LU decomposition into permutations.
@@ -757,8 +778,7 @@ def _lu_pivots_body_fn(i, permutation_and_swaps):
   return permutation.at[iotas + (j,)].set(x), swaps
 
 
-@partial(api.jit, static_argnums=(1,))
-def _generic_lu_pivots_to_permutation(swaps, m):
+def _generic_lu_pivots_to_permutation(swaps, permutation_size):
   """Converts the pivots (row swaps) returned by LU to a permutation.
 
   We build a permutation rather than applying `swaps` directly to the rows
@@ -766,13 +786,14 @@ def _generic_lu_pivots_to_permutation(swaps, m):
 
   Args:
     swaps: an array of shape (..., k) of row swaps to perform
-    m: the size of the output permutation. m should be >= k.
+    permutation_size: the size of the output permutation. Should be >= k.
   Returns:
     An int32 array of shape (..., m).
   """
   assert len(swaps.shape) >= 1
   batch_dims = swaps.shape[:-1]
   k = swaps.shape[-1]
+  m = permutation_size
 
   permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
                                      len(batch_dims))
@@ -812,13 +833,10 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
   return lu_pivots_to_permutation_p.bind(
       x, permutation_size=permutation_size), 0
 
 
-def _lu_pivots_to_permutation_translation_rule(c, pivots, *, permutation_size):
-  lowered_fun = xla.lower_fun(
-      lambda x: _generic_lu_pivots_to_permutation(x, permutation_size),
-      multiple_results=False)
-  return lowered_fun(c, pivots)
+def _lu_pivots_to_permutation_gpu(ctx, avals_in, avals_out, pivots, *,
+                                  permutation_size):
+  return [cuda_linalg.lu_pivots_to_permutation(
+      ctx.builder, pivots, permutation_size=permutation_size)]
 
 lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
 lu_pivots_to_permutation_p.multiple_results = False
@@ -828,12 +846,15 @@ lu_pivots_to_permutation_p.def_abstract_eval(
     _lu_pivots_to_permutation_abstract_eval)
 batching.primitive_batchers[lu_pivots_to_permutation_p] = (
     _lu_pivots_to_permutation_batching_rule)
-xla.translations[lu_pivots_to_permutation_p] = (
-    _lu_pivots_to_permutation_translation_rule)
+xla.register_translation(
+    lu_pivots_to_permutation_p,
+    xla.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False,
+                  new_style=True))
 
 if cuda_linalg:
-  xla.backend_specific_translations['gpu'][lu_pivots_to_permutation_p] = (
-      cuda_linalg.lu_pivots_to_permutation)
+  xla.register_translation(lu_pivots_to_permutation_p,
+                           _lu_pivots_to_permutation_gpu,
+                           platform='gpu')
 
 # LU decomposition
 
@@ -988,49 +1009,50 @@ def _lu_batching_rule(batched_args, batch_dims):
   x = batching.moveaxis(x, bd, 0)
   return lu_p.bind(x), (0, 0, 0)
 
-def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand, backend):
-  shape = c.get_shape(operand)
-  batch_dims = shape.dimensions()[:-2]
-  m = shape.dimensions()[-2]
+def _lu_cpu_gpu_translation_rule(getrf_impl, ctx, avals_in, avals_out, operand):
+  operand_aval, = avals_in
+  c = ctx.builder
+  batch_dims = operand_aval.shape[:-2]
+  m = operand_aval.shape[-2]
   lu, pivot, info = getrf_impl(c, operand)
   # Subtract 1 from the pivot to get 0-based indices.
-  pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
-  ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
+  pivot = xops.Sub(pivot, xops.Constant(c, np.array(1, np.int32)))
+  ok = xops.Ge(info, xops.Constant(c, np.array(0, np.int32)))
   lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                             _nan_like(c, lu))
   perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
-                       multiple_results=False, backend=backend)(c, pivot)
-  return xops.Tuple(c, [lu, pivot, perm])
+                       multiple_results=False, backend=ctx.platform)(c, pivot)
+  return [lu, pivot, perm]
 
 
-def _lu_tpu_translation_rule(c, operand):
-  if hasattr(xops, "LU"):
-    lu, pivot, perm = xops.LU(operand)
-    return xops.Tuple(c, [lu, pivot, perm])
-  else:
-    return xla.lower_fun(_lu_python, multiple_results=True)(c, operand)
+def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand):
+  return xops.LU(operand)
 
 
 lu_p = Primitive('lu')
 lu_p.multiple_results = True
 lu_p.def_impl(_lu_impl)
 lu_p.def_abstract_eval(_lu_abstract_eval)
-xla.translations[lu_p] = xla.lower_fun(_lu_python, multiple_results=True)
+xla.register_translation(lu_p, xla.lower_fun(_lu_python, multiple_results=True,
+                                             new_style=True))
 ad.primitive_jvps[lu_p] = _lu_jvp_rule
 batching.primitive_batchers[lu_p] = _lu_batching_rule
 
-xla.backend_specific_translations['cpu'][lu_p] = partial(
-    _lu_cpu_gpu_translation_rule, lapack.getrf, backend='cpu')
+xla.register_translation(lu_p,
+                         partial(_lu_cpu_gpu_translation_rule, lapack.getrf),
+                         platform='cpu')
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][lu_p] = partial(
-      _lu_cpu_gpu_translation_rule, cusolver.getrf, backend='gpu')
+  xla.register_translation(
+      lu_p, partial(_lu_cpu_gpu_translation_rule, cusolver.getrf),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][lu_p] = partial(
-      _lu_cpu_gpu_translation_rule, rocsolver.getrf, backend='gpu')
+  xla.register_translation(
+      lu_p, partial(_lu_cpu_gpu_translation_rule, rocsolver.getrf),
+      platform='gpu')
 
-xla.backend_specific_translations['tpu'][lu_p] = _lu_tpu_translation_rule
+xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
 
 
 @partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
@@ -1094,8 +1116,8 @@ def qr_impl(operand, full_matrices):
   q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
   return q, r
 
-def qr_translation_rule(c, operand, full_matrices):
-  return xops.Tuple(c, xops.QR(operand, full_matrices))
+def _qr_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices):
+  return xops.QR(operand, full_matrices)
 
 def qr_abstract_eval(operand, full_matrices):
   if isinstance(operand, ShapedArray):
@@ -1137,10 +1159,11 @@ def qr_batching_rule(batched_args, batch_dims, full_matrices):
   x = batching.moveaxis(x, bd, 0)
   return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
 
-def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
-                                 full_matrices):
-  shape = c.get_shape(operand)
-  dims = shape.dimensions()
+def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, ctx, avals_in,
+                                 avals_out, operand, *, full_matrices):
+  c = ctx.builder
+  operand_aval, = avals_in
+  dims = operand_aval.shape
   m, n = dims[-2:]
   batch_dims = dims[:-2]
   r, tau, info_geqrf = geqrf_impl(c, operand)
@@ -1155,13 +1178,13 @@ def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
   else:
     padding_config = [(0, 0, 0)] * len(dims)
     padding_config[-1] = (0, m - n, 0)
-    q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=shape.element_type())),
+    q = xops.Pad(r, xops.Constant(c, np.array(0, dtype=operand_aval.dtype)),
                  xla_client.make_padding_config(padding_config))
   q, info_orgqr = orgqr_impl(c, q, tau)
   if info_geqrf is not None:
     ok = xops.And(
-      xops.Eq(info_geqrf, xops.ConstantLiteral(c, np.array(0, np.int32))),
-      xops.Eq(info_orgqr, xops.ConstantLiteral(c, np.array(0, np.int32))))
+      xops.Eq(info_geqrf, xops.Constant(c, np.array(0, np.int32))),
+      xops.Eq(info_orgqr, xops.Constant(c, np.array(0, np.int32))))
     q = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), q,
                              _nan_like(c, q))
     r = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), r,
@@ -1170,26 +1193,31 @@ def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
     pass  # rocsolver does not return info
 
   r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
-  return xops.Tuple(c, [q, r])
+  return [q, r]
 
 qr_p = Primitive('qr')
 qr_p.multiple_results = True
 qr_p.def_impl(qr_impl)
 qr_p.def_abstract_eval(qr_abstract_eval)
-xla.translations[qr_p] = qr_translation_rule
+xla.register_translation(qr_p, _qr_translation_rule)
 ad.primitive_jvps[qr_p] = qr_jvp_rule
 batching.primitive_batchers[qr_p] = qr_batching_rule
 
-xla.backend_specific_translations['cpu'][qr_p] = partial(
-    _qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr)
+xla.register_translation(
+    qr_p, partial(_qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr),
+    platform='cpu')
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][qr_p] = partial(
-      _qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr)
+  xla.register_translation(
+      qr_p,
+      partial(_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][qr_p] = partial(
-      _qr_cpu_gpu_translation_rule, rocsolver.geqrf, rocsolver.orgqr)
+  xla.register_translation(
+      qr_p,
+      partial(_qr_cpu_gpu_translation_rule, rocsolver.geqrf, rocsolver.orgqr),
+      platform='gpu')
 
 
 # Singular value decomposition
@@ -1198,12 +1226,15 @@ def svd_impl(operand, full_matrices, compute_uv):
   return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                              compute_uv=compute_uv)
 
-def svd_translation_rule(c, operand, full_matrices, compute_uv):
-  shape = c.get_shape(operand).dimensions()
+def _svd_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices,
+                          compute_uv):
+  operand_aval, = avals_in
+  shape = operand_aval.shape
   m, n = shape[-2:]
   if m == 0 or n == 0:
-    return xla.lower_fun(_empty_svd, multiple_results=True)(
-      c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+    return xla.lower_fun(_empty_svd, multiple_results=True, new_style=True)(
+        ctx, avals_in, avals_out, operand, full_matrices=full_matrices,
+        compute_uv=compute_uv)
 
   u, s, v = xops.SVD(operand)
   permutation = list(range(len(shape)))
@@ -1214,9 +1245,9 @@ def svd_translation_rule(c, operand, full_matrices, compute_uv):
   vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2)
 
   if not compute_uv:
-    return xops.Tuple(c, [s])
+    return [s]
   else:
-    return xops.Tuple(c, [s, u, vt])
+    return [s, u, vt]
 
 
 def svd_abstract_eval(operand, full_matrices, compute_uv):
@@ -1293,19 +1324,22 @@ def _empty_svd(a, *, full_matrices, compute_uv):
     u, v = v, u
   return s, u, v
 
-def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
-  shape = c.get_shape(operand).dimensions()
-  m, n = shape[-2:]
-  batch_dims = shape[:-2]
+def _svd_cpu_gpu_translation_rule(gesvd_impl, ctx, avals_in, avals_out, operand,
+                                  *, full_matrices, compute_uv):
+  operand_aval, = avals_in
+  m, n = operand_aval.shape[-2:]
+  batch_dims = operand_aval.shape[:-2]
+  c = ctx.builder
 
   if m == 0 or n == 0:
-    return xla.lower_fun(_empty_svd, multiple_results=True)(
-      c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
+    return xla.lower_fun(_empty_svd, multiple_results=True, new_style=True)(
+        ctx, avals_in, avals_out, operand, full_matrices=full_matrices,
+        compute_uv=compute_uv)
 
   s, u, vt, info = gesvd_impl(c, operand,
                               full_matrices=full_matrices,
                               compute_uv=compute_uv)
-  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
+  ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
   s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                            _nan_like(c, s))
 
@@ -1318,7 +1352,7 @@ def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute
                              _nan_like(c, vt))
     result += [u, vt]
 
-  return xops.Tuple(c, result)
+  return result
 
 def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
   x, = batched_args
@@ -1337,20 +1371,27 @@ svd_p.def_impl(svd_impl)
 svd_p.def_abstract_eval(svd_abstract_eval)
 ad.primitive_jvps[svd_p] = svd_jvp_rule
 batching.primitive_batchers[svd_p] = svd_batching_rule
-xla.translations[svd_p] = svd_translation_rule
+xla.register_translation(svd_p, _svd_translation_rule)
 
-xla.backend_specific_translations['cpu'][svd_p] = partial(
-    _svd_cpu_gpu_translation_rule, lapack.gesdd)
+xla.register_translation(
+    svd_p, partial(_svd_cpu_gpu_translation_rule, lapack.gesdd),
+    platform='cpu')
 
 if cusolver is not None:
-  xla.backend_specific_translations['gpu'][svd_p] = partial(
-      _svd_cpu_gpu_translation_rule, cusolver.gesvd)
+  xla.register_translation(
+      svd_p, partial(_svd_cpu_gpu_translation_rule, cusolver.gesvd),
+      platform='gpu')
 
 if rocsolver is not None:
-  xla.backend_specific_translations['gpu'][svd_p] = partial(
-      _svd_cpu_gpu_translation_rule, rocsolver.gesvd)
+  xla.register_translation(
+      svd_p, partial(_svd_cpu_gpu_translation_rule, rocsolver.gesvd),
+      platform='gpu')
 
 
+def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
+                                            b, *, m, n, ldb, t):
+  return [cusparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
+
 tridiagonal_solve_p = Primitive('tridiagonal_solve')
 tridiagonal_solve_p.multiple_results = False
 tridiagonal_solve_p.def_impl(
@@ -1358,18 +1399,11 @@ tridiagonal_solve_p.def_impl(
 tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
 # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
 if cusparse is not None and hasattr(cusparse, "gtsv2"):
-  xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2
+  xla.register_translation(tridiagonal_solve_p,
+                           _tridiagonal_solve_gpu_translation_rule,
+                           platform='gpu')
 
 
-def _tridiagonal_solve_translation_rule(c, dl, d, du, b, *, m, n, ldb, t):
-  del m, n, ldb, t
-  lowered_fun = xla.lower_fun(_tridiagonal_solve_jax, multiple_results=False)
-  return lowered_fun(c, dl, d, du, b)
-
-xla.translations[tridiagonal_solve_p] = _tridiagonal_solve_translation_rule
-
-
-def _tridiagonal_solve_jax(dl, d, du, b):
+def _tridiagonal_solve_jax(dl, d, du, b, **kw):
   """Pure JAX implementation of `tridiagonal_solve`."""
   prepend_zero = lambda x: jnp.append(jnp.zeros([1], dtype=x.dtype), x[:-1])
   fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
@@ -1397,6 +1431,10 @@ def _tridiagonal_solve_jax(dl, d, du, b):
   return x_[::-1]
 
 
+xla.register_translation(tridiagonal_solve_p, xla.lower_fun(
+    _tridiagonal_solve_jax, multiple_results=False, new_style=True))
+
+
 def tridiagonal_solve(dl, d, du, b):
   r"""Computes the solution of a tridiagonal linear system.
 
@@ -1469,8 +1507,8 @@ def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
                              select_callable=select_callable)
 
 
-def _schur_translation_rule(c, operand, *, compute_schur_vectors,
-                            sort_eig_vals):
+def _schur_translation_rule(ctx, avals_in, avals_out, operand, *,
+                            compute_schur_vectors, sort_eig_vals):
   raise NotImplementedError(
       "Schur decomposition is only implemented on the CPU backend.")
 
@@ -1492,10 +1530,12 @@ def _schur_abstract_eval(operand, *, compute_schur_vectors, sort_eig_vals,
   return (T, vs) if compute_schur_vectors else (T,)
 
 
-def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
-                                sort_eig_vals, select_callable):
-  shape = c.get_shape(operand)
-  batch_dims = shape.dimensions()[:-2]
+def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
+                                compute_schur_vectors, sort_eig_vals,
+                                select_callable):
+  operand_aval, = avals_in
+  batch_dims = operand_aval.shape[:-2]
+  c = ctx.builder
 
   if jaxlib_version < (0, 1, 72):
     raise NotImplementedError(
@@ -1519,7 +1559,7 @@ def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
         sort=sort_eig_vals,
         select=select_callable)
 
-  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
+  ok = xops.Eq(info, xops.Constant(c, np.array(0, np.int32)))
   T = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), T,
                            _nan_like(c, T))
   output = [T]
@@ -1529,7 +1569,7 @@ def _schur_cpu_translation_rule(c, operand, *, compute_schur_vectors,
 
     output.append(vs)
 
-  return xops.Tuple(c, output)
+  return output
 
 
 def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
@@ -1555,7 +1595,7 @@ schur_p = Primitive('schur')
 schur_p.multiple_results = True
 schur_p.def_impl(_schur_impl)
 schur_p.def_abstract_eval(_schur_abstract_eval)
-xla.translations[schur_p] = _schur_translation_rule
-xla.backend_specific_translations['cpu'][schur_p] = _schur_cpu_translation_rule
+xla.register_translation(schur_p, _schur_translation_rule)
+xla.register_translation(schur_p, _schur_cpu_translation_rule, platform='cpu')
 batching.primitive_batchers[schur_p] = _schur_batching_rule
 ad.primitive_jvps[schur_p] = _schur_jvp_rule
@@ -350,21 +350,24 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)
 
 
-def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
-  shape = lax.broadcast_shapes(
-      c.get_shape(k1).dimensions(), c.get_shape(k2).dimensions(),
-      c.get_shape(x1).dimensions(), c.get_shape(x2).dimensions())
-  rank = len(shape)
-  if 0 in shape:
+def _threefry2x32_gpu_translation_rule(ctx, avals_in, avals_out, k1, k2, x1,
+                                       x2):
+  aval_out, _ = avals_out
+  k1_aval, k2_aval, x1_aval, x2_aval = avals_in
+  rank = len(aval_out.shape)
+  if 0 in aval_out.shape:
     zeros = xla_client.ops.Broadcast(
-        xla_client.ops.Constant(c, np.array(0, np.uint32)), shape)
-    return xla_client.ops.Tuple(c, [zeros, zeros])
-  def _broadcast(x):
-    ndims = c.get_shape(x).rank()
-    return xla_client.ops.BroadcastInDim(x, shape,
-                                         tuple(range(rank - ndims, rank)))
-  return cuda_prng.threefry2x32(
-      c, (_broadcast(k1), _broadcast(k2)), (_broadcast(x1), _broadcast(x2)))
+        xla_client.ops.Constant(ctx.builder, np.array(0, np.uint32)),
+        aval_out.shape)
+    return [zeros, zeros]
+  def _broadcast(x, aval):
+    return xla_client.ops.BroadcastInDim(
+        x, aval_out.shape, tuple(range(rank - len(aval.shape), rank)))
+  return xla.xla_destructure(
+      ctx.builder,
+      cuda_prng.threefry2x32(
+          ctx.builder, (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
+          (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
 
 
 threefry2x32_p = core.Primitive("threefry2x32")
@@ -379,8 +382,8 @@ xla.register_translation(threefry2x32_p, xla.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=True),
     multiple_results=True, new_style=True), platform='cpu')
 if cuda_prng:
-  xla.backend_specific_translations['gpu'][threefry2x32_p] = \
-      _threefry2x32_gpu_translation_rule
+  xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
+                           platform='gpu')
 
 
 @partial(jit, inline=True)
@@ -798,7 +798,8 @@ def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
 id_tap_dep_p = core.Primitive("id_tap_dep")
 id_tap_dep_p.multiple_results = False
 id_tap_dep_p.def_impl(lambda r, _: r)
-xla.translations[id_tap_dep_p] = lambda comp, a_res, a_tap: a_res
+xla.register_translation(id_tap_dep_p,
+                         lambda ctx, avals_in, avals_out, a_res, a_tap: [a_res])
 id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
 
 def _id_tap_dep_jvp_rule(primals, tangents):
@@ -923,9 +924,8 @@ def _outside_call_impl(*args, **params):
 outside_call_p.def_impl(_outside_call_impl)
 
 
-def _outside_call_translation_rule(comp: XlaBuilder,
+def _outside_call_translation_rule(ctx, avals_in, avals_out,
                                    *args_op: XlaOp,
-                                   platform="tpu",
                                    has_token,
                                    identity,
                                    flat_results_aval=(),
@@ -934,6 +934,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
   assert has_token
   current_token = args_op[-2]
   current_itoken = args_op[-1]
+  comp = ctx.builder
   assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
       "The last two arguments must be tokens")
 
@@ -944,7 +945,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
                               flat_results_aval))
   need_callback_results_on_device = (not identity and
                                      len(non_empty_flat_results_aval) > 0)
-  use_outfeed = _use_outfeed(platform)
+  use_outfeed = _use_outfeed(ctx.platform)
   send_infeed = use_outfeed and need_callback_results_on_device
   generated_infeed = False  # Keep track if we emitted an infeed op
   if use_outfeed:
@@ -1041,7 +1042,7 @@ def _outside_call_translation_rule(comp: XlaBuilder,
         xla.aval_to_xla_shapes(res_aval)[0]
         for res_aval in callback_flat_results_aval
     ]
-    backend = xb.get_backend(platform)
+    backend = xb.get_backend(ctx.platform)
     token_and_results_op, keep_alive = backend.emit_python_callback(
         wrapped_callback,
         comp,
@@ -1062,12 +1063,10 @@ def _outside_call_translation_rule(comp: XlaBuilder,
   assert identity or len(results) == len(flat_results_aval), (
       f"got {len(results)} but expected {len(flat_results_aval)}. "
      f"identity = {identity}")
-  return xops.Tuple(comp, results + [next_token, next_itoken])
+  return results + [next_token, next_itoken]
 
 
-for platform in ["cpu", "gpu", "tpu"]:
-  xla.backend_specific_translations[platform][outside_call_p] = (
-      functools.partial(_outside_call_translation_rule, platform=platform))
+xla.register_translation(outside_call_p, _outside_call_translation_rule)
 
 
 def _outside_call_run_callback(
@@ -1383,7 +1382,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
     else:
       output_token_var = mk_new_var(last_token_var.aval)
       output_itoken_var = mk_new_var(last_itoken_var.aval)
-      _rewrite_eqn(platform, eqn, eqns, last_token_var, output_token_var,
+      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var,
                    last_itoken_var, output_itoken_var, mk_new_var)
       last_token_var = output_token_var
      last_itoken_var = output_itoken_var
@@ -1393,7 +1392,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
   return new_jaxpr
 
 
-def _rewrite_eqn(platform: str, eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
+def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                  input_token_var: core.Var, output_token_var: core.Var,
                  input_itoken_var: core.Var, output_itoken_var: core.Var,
                  mk_new_var: Callable[[core.AbstractValue], core.Var]):
@@ -1691,7 +1690,7 @@ id_p = core.Primitive("id")
 id_p.multiple_results = True
 id_p.def_impl(lambda *args: args)
 id_p.def_abstract_eval(lambda *args: args)
-xla.translations[id_p] = lambda c, *args: xops.Tuple(c, args)
+xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
 
 xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
 
@@ -237,7 +237,7 @@ def _call_tf_abstract_eval(*_,
 call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
 
 
-def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
+def _call_tf_translation_rule(ctx, avals_in, avals_out, *args_op,
                               function_flat_tf,
                               args_flat_sig_tf,
                               **_):
@@ -245,7 +245,7 @@ def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
   code_gen, _ = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,  # type: ignore
                                           code_gen_optional=False)
   assert code_gen is not None
-  return code_gen(builder, args_op)
+  return code_gen(ctx.builder, args_op)
 
 
 @functools.lru_cache(maxsize=128)
@@ -253,7 +253,8 @@ def _code_generator_and_avals(
     function_flat_tf,
     args_flat_sig_tf,
     code_gen_optional=False
-) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
+) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]],
+                             Sequence[xla.XlaOp]]],
           Sequence[core.ShapedArray]]:
   # Returns and caches a code generator (taking a builder and the
   # XlaOps for the arguments) and a sequence of result abstract shapes.
@@ -384,8 +385,9 @@ def _code_generator_and_avals(
 
   result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore
 
-  def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
-    captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
+  def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]
+               ) -> Sequence[xla.XlaOp]:
+    captured_ops = [xops.Constant(builder, np.asarray(inp))
                     for inp in captured_inputs]
 
    res_tf = xops.Call(builder, xla_comp, args_op + tuple(captured_ops))  # type: ignore
@@ -399,15 +401,14 @@ def _code_generator_and_avals(
           new_element_type=xla.dtype_to_primitive_type(res_aval.dtype))
       return res_op
 
-    results = [
+    return [
        post_process_result(i, res_aval, res_shape)
        for i, (res_aval, res_shape) in enumerate(zip(result_avals,
                                                      result_shapes))]
-    return xops.Tuple(builder, results)
 
   return code_gen, result_avals
 
-xla.translations[call_tf_p] = _call_tf_translation_rule
+xla.register_translation(call_tf_p, _call_tf_translation_rule)
 
 TfVal = jax2tf_internal.TfVal
 def _jax2tf_call_tf(*args: TfVal,
@@ -752,11 +752,13 @@ ad.deflinear2(sharding_constraint_p,
                   sharding_constraint_p.bind(
                       ct, axis_resources=axis_resources, resource_env=resource_env),))
 
-def _sharding_constraint_translation_rule(c, x_node, axis_resources, resource_env):
+def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
+                                          axis_resources, resource_env):
   mesh = resource_env.physical_mesh
-  return xb.set_sharding_proto(c, x_node,
-                               get_sharding_proto(c, x_node, axis_resources, mesh))
-xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
+  return [xb.set_sharding_proto(
+      ctx.builder, x_node,
+      get_sharding_proto(ctx.builder, x_node, axis_resources, mesh))]
+xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)
 
 def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type, vals_in, dims_in,
                                  axis_resources, resource_env):
@@ -188,8 +188,8 @@ def _bcoo_todense_batching_rule(batched_args, batch_dims, *, shape):
 ad.defjvp(bcoo_todense_p, _bcoo_todense_jvp, None)
 ad.primitive_transposes[bcoo_todense_p] = _bcoo_todense_transpose
 batching.primitive_batchers[bcoo_todense_p] = _bcoo_todense_batching_rule
-xla.translations[bcoo_todense_p] = xla.lower_fun(
-    _bcoo_todense_impl, multiple_results=False)
+xla.register_translation(bcoo_todense_p, xla.lower_fun(
+    _bcoo_todense_impl, multiple_results=False, new_style=True))
 
 #--------------------------------------------------------------------
 # bcoo_fromdense
@@ -288,8 +288,8 @@ def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_d
 ad.primitive_jvps[bcoo_fromdense_p] = _bcoo_fromdense_jvp
 ad.primitive_transposes[bcoo_fromdense_p] = _bcoo_fromdense_transpose
 batching.primitive_batchers[bcoo_fromdense_p] = _bcoo_fromdense_batching_rule
-xla.translations[bcoo_fromdense_p] = xla.lower_fun(
-    _bcoo_fromdense_impl, multiple_results=True)
+xla.register_translation(bcoo_fromdense_p, xla.lower_fun(
+    _bcoo_fromdense_impl, multiple_results=True, new_style=True))
 
 #----------------------------------------------------------------------
 # bcoo_extract
@@ -355,8 +355,8 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims):
 ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
 ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
 batching.primitive_batchers[bcoo_extract_p] = _bcoo_extract_batching_rule
-xla.translations[bcoo_extract_p] = xla.lower_fun(
-    _bcoo_extract_impl, multiple_results=False)
+xla.register_translation(bcoo_extract_p, xla.lower_fun(
+    _bcoo_extract_impl, multiple_results=False, new_style=True))
 
 #----------------------------------------------------------------------
 # bcoo_transpose
@@ -450,8 +450,8 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
 ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
 ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
 batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
-xla.translations[bcoo_transpose_p] = xla.lower_fun(
-    _bcoo_transpose_impl, multiple_results=True)
+xla.register_translation(bcoo_transpose_p, xla.lower_fun(
+    _bcoo_transpose_impl, multiple_results=True, new_style=True))
 
 #----------------------------------------------------------------------
 # bcoo_dot_general
@@ -620,8 +620,8 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
 ad.defjvp(bcoo_dot_general_p, _bcoo_dot_general_jvp_lhs, None, _bcoo_dot_general_jvp_rhs)
 ad.primitive_transposes[bcoo_dot_general_p] = _bcoo_dot_general_transpose
 batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
-xla.translations[bcoo_dot_general_p] = xla.lower_fun(
-    _bcoo_dot_general_impl, multiple_results=False)
+xla.register_translation(bcoo_dot_general_p, xla.lower_fun(
+    _bcoo_dot_general_impl, multiple_results=False, new_style=True))
 
 #----------------------------------------------------------------------
 # bcoo_dot_general_sampled
@@ -672,8 +672,8 @@ ad.defjvp(bcoo_dot_general_sampled_p, _bcoo_dot_general_sampled_jvp_A,
           _bcoo_dot_general_sampled_jvp_B, None)
 ad.primitive_transposes[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_transpose
 batching.primitive_batchers[bcoo_dot_general_sampled_p] = _bcoo_dot_general_sampled_batch_rule
-xla.translations[bcoo_dot_general_sampled_p] = xla.lower_fun(
-    _bcoo_dot_general_sampled_impl, multiple_results=False)
+xla.register_translation(bcoo_dot_general_sampled_p, xla.lower_fun(
+    _bcoo_dot_general_sampled_impl, multiple_results=False, new_style=True))
 
 #----------------------------------------------------------------------
 # bcoo_spdot_general
@@ -816,8 +816,8 @@ def _bcoo_spdot_general_batch_rule(batched_args, batch_dims, *, dimension_number
 
 # TODO(JVP): jvp, transpose
 batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule
-xla.translations[bcoo_spdot_general_p] = xla.lower_fun(
-    _bcoo_spdot_general_impl, multiple_results=True)
+xla.register_translation(bcoo_spdot_general_p, xla.lower_fun(
+    _bcoo_spdot_general_impl, multiple_results=True, new_style=True))
 
 #----------------------------------------------------------------------
 # BCOO functions that maybe should be primitives?
@@ -103,14 +103,15 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
   assert indptr.shape[0] == shape[0] + 1
   return core.ShapedArray(shape, data.dtype)
 
-def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
-  return cusparse.csr_todense(c, data, indices, indptr, shape=shape)
+def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
+                                      indptr, *, shape):
+  return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
 
-xla.translations[csr_todense_p] = xla.lower_fun(
-    _csr_todense_impl, multiple_results=False)
+xla.register_translation(csr_todense_p, xla.lower_fun(
+    _csr_todense_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      csr_todense_p] = _csr_todense_gpu_translation_rule
+  xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
+                           platform='gpu')
 
 #--------------------------------------------------------------------
 # csr_fromdense
@@ -159,16 +160,18 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
   indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
   return data, indices, indptr
 
-def _csr_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
+def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
+                                        index_dtype):
   data, indices, indptr = cusparse.csr_fromdense(
-      c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
-  return xops.Tuple(c, [data, indices, indptr])
+      ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
+  return [data, indices, indptr]
 
-xla.translations[csr_fromdense_p] = xla.lower_fun(
-    _csr_fromdense_impl, multiple_results=True)
+xla.register_translation(csr_fromdense_p, xla.lower_fun(
+    _csr_fromdense_impl, multiple_results=True, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      csr_fromdense_p] = _csr_fromdense_gpu_translation_rule
+  xla.register_translation(csr_fromdense_p,
+                           _csr_fromdense_gpu_translation_rule,
+                           platform='gpu')
 
 #--------------------------------------------------------------------
 # csr_matvec
@@ -211,14 +214,16 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
   assert v.shape[0] == (shape[0] if transpose else shape[1])
   return core.ShapedArray((out_shape,), data.dtype)
 
-def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
-  return cusparse.csr_matvec(c, data, indices, indptr, v, shape=shape, transpose=transpose)
+def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
+                                     indptr, v, *, shape, transpose):
+  return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
+                              shape=shape, transpose=transpose)]
 
-xla.translations[csr_matvec_p] = xla.lower_fun(
-    _csr_matvec_impl, multiple_results=False)
+xla.register_translation(csr_matvec_p, xla.lower_fun(
+    _csr_matvec_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      csr_matvec_p] = _csr_matvec_gpu_translation_rule
+  xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
+                           platform='gpu')
 
 
 #--------------------------------------------------------------------
@@ -263,14 +268,16 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
   assert B.shape[0] == (shape[0] if transpose else shape[1])
   return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
 
-def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
-  return cusparse.csr_matmat(c, data, indices, indptr, B, shape=shape, transpose=transpose)
+def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
+                                     indptr, B, *, shape, transpose):
+  return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
+                              shape=shape, transpose=transpose)]
 
-xla.translations[csr_matmat_p] = xla.lower_fun(
-    _csr_matmat_impl, multiple_results=False)
+xla.register_translation(csr_matmat_p, xla.lower_fun(
+    _csr_matmat_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      csr_matmat_p] = _csr_matmat_gpu_translation_rule
+  xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
                           platform='gpu')
 
 
#--------------------------------------------------------------------
@@ -300,8 +307,9 @@ def _coo_todense_impl(data, row, col, *, shape):
 def _coo_todense_abstract_eval(data, row, col, *, shape):
   return core.ShapedArray(shape, data.dtype)
 
-def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
-  return cusparse.coo_todense(c, data, row, col, shape=shape)
+def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
+                                      *, shape):
+  return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
 
 def _coo_todense_jvp(data_dot, data, row, col, *, shape):
   return coo_todense(data_dot, row, col, shape=shape)
@@ -319,11 +327,11 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
 
 ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
-xla.translations[coo_todense_p] = xla.lower_fun(
-    _coo_todense_impl, multiple_results=False)
+xla.register_translation(coo_todense_p, xla.lower_fun(
+    _coo_todense_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      coo_todense_p] = _coo_todense_gpu_translation_rule
+  xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
+                           platform='gpu')
 
 #--------------------------------------------------------------------
 # coo_fromdense
@@ -367,10 +375,11 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
   row = col = core.ShapedArray((nse,), index_dtype)
   return data, row, col
 
-def _coo_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
+def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
+                                        index_dtype):
   data, row, col = cusparse.coo_fromdense(
-      c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
-  return xops.Tuple(c, [data, row, col])
+      ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
+  return [data, row, col]
 
 def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
   M, = primals
@@ -400,11 +409,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
 ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
 
-xla.translations[coo_fromdense_p] = xla.lower_fun(
-    _coo_fromdense_impl, multiple_results=True)
+xla.register_translation(coo_fromdense_p, xla.lower_fun(
+    _coo_fromdense_impl, multiple_results=True, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
+  xla.register_translation(coo_fromdense_p,
+                           _coo_fromdense_gpu_translation_rule,
+                           platform='gpu')
 
 #--------------------------------------------------------------------
 # coo_matvec
@@ -450,8 +460,10 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
   out_shape = shape[1] if transpose else shape[0]
   return core.ShapedArray((out_shape,), data.dtype)
 
-def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
-  return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
+def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
+                                     v, *, shape, transpose):
+  return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
+                              transpose=transpose)]
 
 def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
   return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
@@ -472,11 +484,11 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
 
 ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
-xla.translations[coo_matvec_p] = xla.lower_fun(
-    _coo_matvec_impl, multiple_results=False)
+xla.register_translation(coo_matvec_p, xla.lower_fun(
+    _coo_matvec_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      coo_matvec_p] = _coo_matvec_gpu_translation_rule
+  xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
+                           platform='gpu')
 
 #--------------------------------------------------------------------
 # coo_matmat
@@ -521,14 +533,16 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
   out_shape = shape[1] if transpose else shape[0]
   return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
 
-def _coo_matmat_gpu_translation_rule(c, data, row, col, B, *, shape, transpose):
-  return cusparse.coo_matmat(c, data, row, col, B, shape=shape, transpose=transpose)
+def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
+                                     B, *, shape, transpose):
+  return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
+                              transpose=transpose)]
 
-xla.translations[coo_matmat_p] = xla.lower_fun(
-    _coo_matmat_impl, multiple_results=False)
+xla.register_translation(coo_matmat_p, xla.lower_fun(
+    _coo_matmat_impl, multiple_results=False, new_style=True))
 if cusparse and cusparse.is_supported:
-  xla.backend_specific_translations['gpu'][
-      coo_matmat_p] = _coo_matmat_gpu_translation_rule
+  xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
                           platform='gpu')
 
 def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
   vals, rows, cols, mat = primals_in
@@ -409,15 +409,17 @@ def _sharding_constraint_impl(x, partitions):
   raise NotImplementedError(
      "with_sharding_constraint() should only be called inside sharded_jit()")
 
-def _sharding_constraint_translation_rule(c, x_node, partitions):
-  return xb.set_sharding(c, x_node, partitions)
+def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
+                                          partitions):
+  return [xb.set_sharding(ctx.builder, x_node, partitions)]
 
 sharding_constraint_p = core.Primitive("sharding_constraint")
 sharding_constraint_p.def_impl(_sharding_constraint_impl)
 sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
 ad.deflinear2(sharding_constraint_p,
               lambda ct, _, partitions: (with_sharding_constraint(ct, partitions),))
-xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
+xla.register_translation(sharding_constraint_p,
+                         _sharding_constraint_translation_rule)
 
 def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
   """Identity-like function that specifies how ``x`` should be sharded.
@@ -2699,7 +2699,8 @@ class APITest(jtu.JaxTestCase):
     tokentest_p = core.Primitive("tokentest")
     tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
     tokentest_p.def_abstract_eval(lambda x, y: x)
-    xla.translations[tokentest_p] = lambda c, x, y: x
+    xla.register_translation(tokentest_p,
+                             lambda ctx, avals_in, avals_out, x, y: [x])
     ad.defjvp(tokentest_p, (lambda g, x, token: x), None)
 
     token = jax.lax.create_token(123)
@@ -145,12 +145,12 @@ def _sp_indices_impl(mat):
 def _sp_indices_abstract_eval(mat):
   return mat.indices_aval
 
-def _sp_indices_translation_rule(c, data, indices):
-  return indices
+def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
+  return [indices]
 
 # Note: cannot use lower_fun to define attribute access primitives
 # because it leads to infinite recursion.
-xla.translations[sp_indices_p] = _sp_indices_translation_rule
+xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
 
 sp_data_p = core.Primitive('sp_data')
 
@@ -162,12 +162,12 @@ def _sp_data_impl(mat):
 def _sp_data_abstract_eval(mat):
   return mat.data_aval
 
-def _sp_data_translation_rule(c, data, indices):
-  return data
+def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
+  return [data]
 
 # Note: cannot use lower_fun to define attribute access primitives
 # because it leads to infinite recursion.
-xla.translations[sp_data_p] = _sp_data_translation_rule
+xla.register_translation(sp_data_p, _sp_data_translation_rule)
 
 def identity(x):
   return identity_p.bind(x)
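
A pattern that recurs throughout the diff is xla.lower_fun, which lowers a pure JAX-traceable function into a translation rule; in the new style it is built with new_style=True and handed to xla.register_translation, optionally per platform. A minimal sketch under the same 2021-era API (my_sparse_p and _my_sparse_impl are hypothetical names, not from the diff):

    import jax.numpy as jnp
    from jax import core
    from jax.interpreters import xla

    my_sparse_p = core.Primitive("my_sparse")  # hypothetical primitive

    def _my_sparse_impl(x):
      # Pure, JAX-traceable default implementation; lower_fun traces it and
      # emits the resulting XLA ops as the translation of the primitive.
      return jnp.where(x != 0, x, 0)

    # Default rule for all backends:
    xla.register_translation(my_sparse_p, xla.lower_fun(
        _my_sparse_impl, multiple_results=False, new_style=True))

    # A hand-written rule registered with platform='gpu' (or 'cpu', 'tpu')
    # overrides the default on that backend only, as the cusparse- and
    # cusolver-backed rules above do.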