mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Eigh primitive is now a customcall
PiperOrigin-RevId: 518074163
This commit is contained in:
parent
bf416a8b5c
commit
143dfcd74b
@ -597,19 +597,40 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
w, v = operand, operand
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_translation_rule(ctx, avals_in, avals_out, operand, *, lower,
|
||||
sort_eigenvalues):
|
||||
operand_aval, = avals_in
|
||||
|
||||
def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
|
||||
operand_aval, = ctx.avals_in
|
||||
if operand_aval.shape[-1] == 0:
|
||||
return [xops.Real(xops.Reshape(operand, operand_aval.shape[:-1])), operand]
|
||||
v, w = xops.Eigh(operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
||||
return w, v
|
||||
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
|
||||
return [
|
||||
hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
|
||||
operand,
|
||||
]
|
||||
|
||||
eigvals_type = mlir.aval_to_ir_type(ctx.avals_out[0])
|
||||
eigvecs_type = mlir.aval_to_ir_type(ctx.avals_out[1])
|
||||
eigh_type = ir.TupleType.get_tuple([eigvecs_type, eigvals_type])
|
||||
|
||||
backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"
|
||||
op = hlo.CustomCallOp(
|
||||
[eigh_type],
|
||||
[operand],
|
||||
call_target_name=ir.StringAttr.get("Eigh"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(backend_config),
|
||||
api_version=mlir.i32_attr(1),
|
||||
)
|
||||
return (
|
||||
hlo.GetTupleElementOp(op, 1).result,
|
||||
hlo.GetTupleElementOp(op, 0).result,
|
||||
)
|
||||
|
||||
|
||||
eigh_jacobi_p = Primitive('eigh_jacobi')
|
||||
eigh_jacobi_p.multiple_results = True
|
||||
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
|
||||
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
|
||||
xla.register_translation(eigh_jacobi_p, _eigh_jacobi_translation_rule)
|
||||
mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
|
||||
|
||||
|
||||
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
||||
|
Loading…
x
Reference in New Issue
Block a user