[MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.

PiperOrigin-RevId: 441769591
Peter Hawkins 2022-04-14 08:37:43 -07:00 committed by jax authors
parent 4806c29bf7
commit 6c1461b52b
2 changed files with 166 additions and 1 deletion
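For orientation, the three primitives touched here are reached from user-facing APIs roughly as follows; a minimal sketch, assuming the jax.numpy / jax.lax.linalg / jax.scipy.linalg entry points of this era:

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.scipy import linalg as jsl

a = jnp.asarray(np.random.randn(4, 4).astype(np.float32))
spd = a @ a.T + 4 * jnp.eye(4)                   # symmetric positive definite

l = jnp.linalg.cholesky(spd)                     # cholesky_p
x = lax.linalg.triangular_solve(l, jnp.eye(4),   # triangular_solve_p
                                left_side=True, lower=True)
t, vs = jsl.schur(a)                             # schur_p (CPU-only primitive)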

jax/_src/lax/linalg.py

@@ -348,6 +348,13 @@ cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule

# Lower directly to MHLO's CholeskyOp; JAX always takes the lower-triangular
# factor here.
def _cholesky_lowering(ctx, x):
  aval, = ctx.avals_out
  return mhlo.CholeskyOp(mlir.aval_to_ir_type(aval), x,
                         lower=ir.BoolAttr.get(True)).results

mlir.register_lowering(cholesky_p, _cholesky_lowering)

def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out,
                                       operand):
  operand_aval, = avals_in
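To see the new lowering in action, one can dump the MHLO for a jitted cholesky. A sketch, assuming this jax version's Lowered.compiler_ir accepts the "mhlo" dialect name:

import jax
import jax.numpy as jnp

lowered = jax.jit(jnp.linalg.cholesky).lower(jnp.eye(3))
print(lowered.compiler_ir(dialect="mhlo"))  # module should contain mhlo.cholesky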
@@ -787,6 +794,24 @@ ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule

def _triangular_solve_lowering(
    ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  out_aval, = ctx.avals_out
  if conjugate_a and not transpose_a:
    # MHLO has no "conjugate without transpose" mode; conjugate a explicitly.
    a = chlo.ConjOp(a).result
    conjugate_a = False
  if not transpose_a:
    transpose = "NO_TRANSPOSE"
  else:
    transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
  return mhlo.TriangularSolveOp(
      mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side),
      ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
      mhlo.TransposeAttr.get(transpose)).results

mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)

def _triangular_solve_cpu_translation_rule(
    ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
    conjugate_a, unit_diagonal):
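The lowering folds the four (transpose_a, conjugate_a) combinations onto MHLO's three transpose modes; conjugation without transposition has no dedicated mode, so it becomes an explicit chlo conj on a first. Restated as a sketch (hypothetical table, not part of the change):

# (transpose_a, conjugate_a) -> (conjugate a first?, MHLO transpose mode)
_TRANSPOSE_MODES = {
    (False, False): (False, "NO_TRANSPOSE"),
    (False, True):  (True,  "NO_TRANSPOSE"),  # conj folded into a
    (True,  False): (False, "TRANSPOSE"),
    (True,  True):  (False, "ADJOINT"),       # conjugate transpose
}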
@@ -1862,7 +1887,7 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
  _cpu_gees = lapack.gees
  if sort_eig_vals:
-    T, vs, sdim, info = _cpu_gees(
+    T, vs, _sdim, info = _cpu_gees(
        c,
        operand,
        jobvs=compute_schur_vectors,
@@ -1888,6 +1913,49 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
  return output

def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
                        select_callable):
  operand_aval, = ctx.avals_in
  batch_dims = operand_aval.shape[:-2]
  if sort_eig_vals:
    T, vs, _sdim, info = lapack.gees_mhlo(
        operand,
        jobvs=compute_schur_vectors,
        sort=sort_eig_vals,
        select=select_callable)
  else:
    T, vs, info = lapack.gees_mhlo(
        operand,
        jobvs=compute_schur_vectors,
        sort=sort_eig_vals,
        select=select_callable)

  # info == 0 means LAPACK succeeded for that batch element; outputs for
  # failed elements are replaced with NaNs below.
  ok = mlir.compare_mhlo(
      info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
      "EQ", "SIGNED")
  T = _broadcasting_select_mhlo(
      mhlo.BroadcastInDimOp(
          ir.RankedTensorType.get(batch_dims + (1, 1),
                                  ir.IntegerType.get_signless(1)),
          ok,
          mlir.dense_int_elements(range(len(batch_dims)))).result,
      T, _nan_like_mhlo(ctx.avals_out[0]))
  output = [T]
  if compute_schur_vectors:
    vs = _broadcasting_select_mhlo(
        mhlo.BroadcastInDimOp(
            ir.RankedTensorType.get(batch_dims + (1, 1),
                                    ir.IntegerType.get_signless(1)),
            ok,
            mlir.dense_int_elements(range(len(batch_dims)))).result,
        vs, _nan_like_mhlo(ctx.avals_out[1]))
    output.append(vs)
  return output

def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
                         sort_eig_vals, select_callable):
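The select logic in _schur_cpu_lowering broadcasts the per-batch ok = (info == 0) predicate over the trailing matrix dimensions and overwrites failed batch entries with NaN. A NumPy restatement for intuition (hypothetical helper, not part of the change):

import numpy as np

def _mask_lapack_failures(t, info):
  # t: (..., m, n) LAPACK output; info: (...,) status codes, 0 == success.
  ok = (info == 0)[..., None, None]  # same shape expansion as broadcast_in_dim
  return np.where(ok, t, np.nan)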
@@ -1914,6 +1982,9 @@ schur_p.def_impl(_schur_impl)
schur_p.def_abstract_eval(_schur_abstract_eval)
xla.register_translation(schur_p, _schur_translation_rule)
xla.register_translation(schur_p, _schur_cpu_translation_rule, platform='cpu')
mlir.register_lowering(schur_p, _schur_translation_rule)
if jax._src.lib.version >= (0, 3, 6):
  mlir.register_lowering(schur_p, _schur_cpu_lowering, platform='cpu')
batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule
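The version guard keeps the MHLO CPU path off for older jaxlib; presumably 0.3.6 is the first jaxlib release shipping the gees_mhlo builder added below (an assumption based on the guard itself). A quick check of which path a given install takes (jax._src.lib.version is an internal detail, so this is illustrative only):

import jax._src.lib
# True -> the gees_mhlo-based CPU lowering for schur_p is registered.
print(jax._src.lib.version >= (0, 3, 6))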

jaxlib/lapack.py

@@ -1307,3 +1307,97 @@ def gees(c, a, jobvs=True, sort=False, select=None):
  else:
    return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 3),
            _ops.GetTupleElement(out, 5))

# Emits an MHLO custom call to LAPACK's ?gees (real/complex Schur decomposition).
def gees_mhlo(a, jobvs=True, sort=False, select=None):
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # b = total number of batched matrices.
  b = 1
  for d in batch_dims:
    b *= d
  # Minor-to-major layout: column-major matrices, batch dimensions major.
  layout = ir.DenseIntElementsAttr.get(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      type=ir.IndexType.get())

  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")

  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')

  if not ir.ComplexType.isinstance(etype):
    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
    schurvecs_type = etype
    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
    workspace_layouts = [layout]
    # wr, wi: real and imaginary parts of the eigenvalues.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ] * 2
  else:
    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
          else "lapack_zgees")
    schurvecs_type = etype
    workspaces = [
        ir.RankedTensorType.get(dims, schurvecs_type),
        # rwork buffer for the complex routines.
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [
        layout,
        ir.DenseIntElementsAttr.get(np.array([0]), type=ir.IndexType.get()),
    ]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
    eigvals_layouts = [
        ir.DenseIntElementsAttr.get(np.arange(num_bd, -1, -1),
                                    type=ir.IndexType.get())
    ]

  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = ir.DenseIntElementsAttr.get(np.zeros((0,), np.int64),
                                              type=ir.IndexType.get())
  out = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple(workspaces + eigvals + [
          ir.RankedTensorType.get(dims, schurvecs_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
      ])],
      [
          _mhlo_s32(b),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [layout]),
      result_layouts=ir.ArrayAttr.get(workspace_layouts + eigvals_layouts + [
          layout,
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
          ir.DenseIntElementsAttr.get(np.arange(num_bd - 1, -1, -1),
                                      type=ir.IndexType.get()),
      ]))
  i32_attr = lambda i: ir.IntegerAttr.get(i32_type, i)
  if sort == ord('S'):
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(4)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
  else:
    return (mhlo.GetTupleElementOp(out, i32_attr(0)).result,
            mhlo.GetTupleElementOp(out, i32_attr(3)).result,
            mhlo.GetTupleElementOp(out, i32_attr(5)).result)
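For reference, the tuple slots unpacked by GetTupleElementOp follow from the result types assembled above. Real-dtype case shown; the complex case is [T, rwork, w, vs, sdim, info], so slots 0, 3, 4, and 5 are unchanged:

#   0: T     (..., n, n)  Schur form (input workspace, reused as output)
#   1: wr    (..., n)     eigenvalue real parts
#   2: wi    (..., n)     eigenvalue imaginary parts
#   3: vs    (..., n, n)  Schur vectors
#   4: sdim  (...,)       count of selected eigenvalues (sorting path only)
#   5: info  (...,)       LAPACK status; 0 == success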