Eigh primitive is now a customcall

PiperOrigin-RevId: 518074163
This commit is contained in:
Anish Tondwalkar 2023-03-20 14:16:37 -07:00 committed by jax authors
parent bf416a8b5c
commit 143dfcd74b

View File

@ -597,19 +597,40 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
w, v = operand, operand
return w, v
def _eigh_jacobi_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
sort_eigenvalues):
operand_aval, = avals_in
def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
operand_aval, = ctx.avals_in
if operand_aval.shape[-1] == 0:
return [xops.Real(xops.Reshape(operand, operand_aval.shape[:-1])), operand]
v, w = xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
return w, v
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
return [
hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
operand,
]
eigvals_type = mlir.aval_to_ir_type(ctx.avals_out[0])
eigvecs_type = mlir.aval_to_ir_type(ctx.avals_out[1])
eigh_type = ir.TupleType.get_tuple([eigvecs_type, eigvals_type])
backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"
op = hlo.CustomCallOp(
[eigh_type],
[operand],
call_target_name=ir.StringAttr.get("Eigh"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(backend_config),
api_version=mlir.i32_attr(1),
)
return (
hlo.GetTupleElementOp(op, 1).result,
hlo.GetTupleElementOp(op, 0).result,
)
eigh_jacobi_p = Primitive('eigh_jacobi')
eigh_jacobi_p.multiple_results = True
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
xla.register_translation(eigh_jacobi_p, _eigh_jacobi_translation_rule)
mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
def _eigh_impl(operand, *, lower, sort_eigenvalues):