mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
This commit is contained in:
parent
4806c29bf7
commit
6c1461b52b
@ -348,6 +348,13 @@ cholesky_p = standard_unop(_float | _complex, 'cholesky')
|
||||
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
|
||||
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule
|
||||
|
||||
def _cholesky_lowering(ctx, x):
  """MHLO lowering rule for the cholesky primitive.

  Emits a single mhlo.CholeskyOp computing the lower-triangular factor.
  """
  out_aval, = ctx.avals_out
  result_type = mlir.aval_to_ir_type(out_aval)
  # lower=True: this primitive always produces the lower Cholesky factor.
  lower_attr = ir.BoolAttr.get(True)
  return mhlo.CholeskyOp(result_type, x, lower=lower_attr).results

mlir.register_lowering(cholesky_p, _cholesky_lowering)
|
||||
|
||||
def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out,
|
||||
operand):
|
||||
operand_aval, = avals_in
|
||||
@ -787,6 +794,24 @@ ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
|
||||
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
|
||||
|
||||
|
||||
def _triangular_solve_lowering(
    ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """MHLO lowering rule for the triangular_solve primitive.

  Maps JAX's separate (transpose_a, conjugate_a) flags onto the single
  transpose attribute accepted by mhlo.TriangularSolveOp.
  """
  out_aval, = ctx.avals_out
  # TriangularSolveOp has no "conjugate without transpose" mode, so in that
  # case conjugate the operand explicitly and clear the flag.
  if conjugate_a and not transpose_a:
    a = chlo.ConjOp(a)
    conjugate_a = False
  if transpose_a:
    transpose_mode = "ADJOINT" if conjugate_a else "TRANSPOSE"
  else:
    transpose_mode = "NO_TRANSPOSE"
  result_type = mlir.aval_to_ir_type(out_aval)
  return mhlo.TriangularSolveOp(
      result_type, a, b,
      ir.BoolAttr.get(left_side),
      ir.BoolAttr.get(lower),
      ir.BoolAttr.get(unit_diagonal),
      mhlo.TransposeAttr.get(transpose_mode)).results

mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
||||
|
||||
|
||||
def _triangular_solve_cpu_translation_rule(
|
||||
ctx, avals_in, avals_out, a, b, *, left_side, lower, transpose_a,
|
||||
conjugate_a, unit_diagonal):
|
||||
@ -1862,7 +1887,7 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
|
||||
_cpu_gees = lapack.gees
|
||||
|
||||
if sort_eig_vals:
|
||||
T, vs, sdim, info = _cpu_gees(
|
||||
T, vs, _sdim, info = _cpu_gees(
|
||||
c,
|
||||
operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
@ -1888,6 +1913,49 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
|
||||
|
||||
return output
|
||||
|
||||
def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
                        select_callable):
  """CPU MHLO lowering for the Schur primitive, backed by LAPACK ?gees.

  Calls jaxlib's gees custom-call emitter and replaces outputs of failed
  factorizations (nonzero LAPACK ``info``) with NaNs.
  """
  operand_aval, = ctx.avals_in
  batch_dims = operand_aval.shape[:-2]

  gees_out = lapack.gees_mhlo(
      operand,
      jobvs=compute_schur_vectors,
      sort=sort_eig_vals,
      select=select_callable)
  # With sorting enabled gees additionally returns sdim, which we discard.
  if sort_eig_vals:
    T, vs, _sdim, info = gees_out
  else:
    T, vs, info = gees_out

  # info == 0 per batch element means the factorization succeeded.
  zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
  ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED")

  def _nan_out_failures(x, out_aval):
    # Broadcast the per-batch success predicate to the output's rank and
    # select NaNs wherever the factorization failed.
    pred = mhlo.BroadcastInDimOp(
        ir.RankedTensorType.get(batch_dims + (1, 1),
                                ir.IntegerType.get_signless(1)),
        ok,
        mlir.dense_int_elements(range(len(batch_dims)))).result
    return _broadcasting_select_mhlo(pred, x, _nan_like_mhlo(out_aval))

  output = [_nan_out_failures(T, ctx.avals_out[0])]
  if compute_schur_vectors:
    output.append(_nan_out_failures(vs, ctx.avals_out[1]))
  return output
|
||||
|
||||
|
||||
def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
|
||||
sort_eig_vals, select_callable):
|
||||
@ -1914,6 +1982,9 @@ schur_p.def_impl(_schur_impl)
|
||||
# Registrations for the Schur primitive: abstract eval, XLA translation
# rules (generic plus a LAPACK-backed CPU specialization), MHLO lowerings,
# and the batching/JVP rules.
schur_p.def_abstract_eval(_schur_abstract_eval)
xla.register_translation(schur_p, _schur_translation_rule)
xla.register_translation(schur_p, _schur_cpu_translation_rule, platform='cpu')
mlir.register_lowering(schur_p, _schur_translation_rule)
# The MHLO gees kernel only exists in sufficiently new jaxlib versions, so
# gate the CPU-specific lowering on the installed library version.
if jax._src.lib.version >= (0, 3, 6):
  mlir.register_lowering(schur_p, _schur_cpu_lowering, platform='cpu')
batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule
|
||||
|
||||
|
@ -1307,3 +1307,97 @@ def gees(c, a, jobvs=True, sort=False, select=None):
|
||||
else:
|
||||
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 3),
|
||||
_ops.GetTupleElement(out, 5))
|
||||
|
||||
def gees_mhlo(a, jobvs=True, sort=False, select=None):
  """Emits an MHLO custom call to LAPACK's ?gees (Schur decomposition).

  Args:
    a: an MHLO value holding a (batched) square matrix.
    jobvs: whether the Schur vectors should be computed.
    sort: eigenvalue sorting; not implemented, must be falsy.
    select: sorting callback; unused until sorting is supported.

  Returns:
    A tuple of MHLO values ``(T, vs, info)`` — Schur form, Schur vectors,
    and per-batch LAPACK status — plus ``sdim`` before ``info`` when
    sorting is requested (currently unreachable, see below).
  """
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  index_type = ir.IndexType.get()

  def _layout_attr(values):
    # Custom-call layouts are index-typed dense integer attributes.
    return ir.DenseIntElementsAttr.get(np.asarray(values), type=index_type)

  # Column-major over the trailing matrix dims, row-major over batch dims.
  matrix_layout = _layout_attr(
      np.array((num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))))

  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")

  # Encode the mode flags the way the LAPACK kernel expects them.
  jobvs = ord('V' if jobvs else 'N')
  sort = ord('S' if sort else 'N')

  # Layout for per-batch vectors of length n (eigenvalue outputs).
  batch_vec_layout = _layout_attr(np.arange(num_bd, -1, -1))

  if ir.ComplexType.isinstance(etype):
    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
          else "lapack_zgees")
    schurvecs_type = etype
    # Complex gees needs an extra real-valued rwork buffer of length n.
    workspaces = [
        ir.RankedTensorType.get(dims, schurvecs_type),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [matrix_layout, _layout_attr(np.array([0]))]
    # Complex eigenvalues come back in a single output.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
    eigvals_layouts = [batch_vec_layout]
  else:
    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
    schurvecs_type = etype
    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
    workspace_layouts = [matrix_layout]
    # Real gees returns separate real and imaginary eigenvalue parts.
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
    eigvals_layouts = [batch_vec_layout] * 2

  i32_type = ir.IntegerType.get_signless(32)
  batch_layout = _layout_attr(np.arange(num_bd - 1, -1, -1))
  scalar_layout = _layout_attr(np.zeros((0,), np.int64))

  result_types = workspaces + eigvals + [
      # Schur vectors, then two per-batch i32 outputs (sdim and info).
      ir.RankedTensorType.get(dims, schurvecs_type),
      ir.RankedTensorType.get(batch_dims, i32_type),
      ir.RankedTensorType.get(batch_dims, i32_type),
  ]
  out = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple(result_types)],
      [
          _mhlo_s32(batch_size),
          _mhlo_s32(n),
          _mhlo_u8(np.uint8(jobvs)),
          _mhlo_u8(np.uint8(sort)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      call_target_name=ir.StringAttr.get(fn),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=ir.ArrayAttr.get([scalar_layout] * 4 + [matrix_layout]),
      result_layouts=ir.ArrayAttr.get(
          workspace_layouts + eigvals_layouts +
          [matrix_layout, batch_layout, batch_layout]))

  def _tuple_element(i):
    return mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result

  # Result-tuple slots used below: 0 = first workspace (callers unpack it as
  # the Schur form T), 3 = Schur vectors, 4 = sdim, 5 = info.
  if sort == ord('S'):
    # NOTE(review): unreachable today — sort=True raises above. Kept so the
    # return shape is ready once sorting is implemented.
    return (_tuple_element(0), _tuple_element(3), _tuple_element(4),
            _tuple_element(5))
  else:
    return (_tuple_element(0), _tuple_element(3), _tuple_element(5))
|
||||
|
Loading…
x
Reference in New Issue
Block a user