Reapply: Use input-output aliasing for jaxlib GPU custom calls.

Previously we had no way to tell XLA that the inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU, so we can simply ask XLA to enforce the aliasing we need.

It turns out some users rely on the API contract of these custom calls remaining stable within serialized HLO. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, so this gets us the full benefit of saving a copy. We can break backwards compatibility on the serialized HLO once users have upgraded their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
commit 22cd50535b
parent 707b07c1e9
Author: Peter Hawkins, 2022-10-10 11:27:51 -07:00 (committed by jax authors)
3 changed files with 62 additions and 32 deletions
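To make the tradeoff described above concrete: the GPU kernels behind these custom calls operate in place, overwriting one of their input buffers with the result. The following is a hedged Python sketch of the tolerance the commit message describes (illustrative only; jaxlib's actual kernels are C++, and `run_in_place` is a hypothetical name):

import numpy as np

def run_in_place(kernel, input_buf, output_buf):
    # If XLA honored the requested alias, the two buffers are the same
    # object and the defensive copy is skipped. If an older, unaliased
    # HLO is executed, we pay for one copy and still get a correct result.
    if output_buf is not input_buf:
        np.copyto(output_buf, input_buf)
    kernel(output_buf)  # the kernel always works in place
    return output_buf

With the `operand_output_aliases` requests added below, XLA:GPU can hand the kernel a single shared buffer, so the copy branch never runs.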

@@ -98,7 +98,8 @@ def _trsm_mhlo(platform, gpu_blas, dtype, a, b, left_side=False, lower=False,
       [a, b],
       backend_config=opaque,
       operand_layouts=[layout] * 2,
-      result_layouts=[layout, work_layout, work_layout])
+      result_layouts=[layout, work_layout, work_layout],
+      operand_output_aliases={1: 0})
   return out[0]
 
 cuda_trsm = partial(_trsm_mhlo, "cu", _cublas)
@@ -133,7 +134,8 @@ def _potrf_mhlo(platform, gpu_solver, dtype, a, lower):
       [a],
       backend_config=opaque,
       operand_layouts=[layout],
-      result_layouts=[layout, info_layout, work_layout])
+      result_layouts=[layout, info_layout, work_layout],
+      operand_output_aliases={0: 0})
   return out[:2]
 
 cuda_potrf = partial(_potrf_mhlo, "cu", _cusolver)
@@ -179,7 +181,8 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ])
+      ],
+      operand_output_aliases={0: 0})
   return out[:3]
 
 cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
@@ -217,7 +220,8 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ])
+      ],
+      operand_output_aliases={0: 0})
   return out[:3]
 
 cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
@@ -253,7 +257,9 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
           tuple(range(num_bd, -1, -1)),
           [0],
           [0],
-      ])
+      ],
+      operand_output_aliases={0: 0}
+  )
   return out[:2]
 
 cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
@@ -321,7 +327,8 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
           layout,
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ])
+      ],
+      operand_output_aliases={0: 0})
   return out[:2]
 
 cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
@@ -372,7 +379,8 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
           tuple(range(num_bd, -1, -1)),
           tuple(range(num_bd - 1, -1, -1)),
           [0],
-      ])
+      ],
+      operand_output_aliases={0: 0})
   return out[:3]
 
 cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
@@ -427,7 +435,8 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
             matrix_layout,
             scalar_layout,
             [0],
-        ])
+        ],
+        operand_output_aliases={0: 0})
     vt = mhlo.TransposeOp(
         v,
         ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
@@ -469,7 +478,8 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
             matrix_layout,
             scalar_layout,
             [0],
-        ])
+        ],
+        operand_output_aliases={0: 0})
   else:
     lwork, opaque = gpu_solver.build_gesvd_descriptor(
         np.dtype(dtype), b, m, n, compute_uv, full_matrices)
@@ -495,7 +505,8 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
             matrix_layout,
             scalar_layout,
             [0],
-        ])
+        ],
+        operand_output_aliases={0: 0})
   return s, u, vt, info
 
 cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
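For reference, the aliases requested in the hunks above, collected as {operand: output} (a reading aid, not part of the diff). The cuSOLVER/cuBLAS kernels follow LAPACK's convention of overwriting the input matrix, so nearly every wrapper aliases operand 0; trsm instead aliases the right-hand side b, which the solver overwrites with the solution:

SOLVER_ALIASES = {
    "trsm": {1: 0},           # solution overwrites the right-hand side b
    "potrf": {0: 0},          # Cholesky factor overwrites a
    "getrf": {0: 0},          # LU factors overwrite a
    "geqrf": {0: 0},          # packed QR factor overwrites a
    "geqrf_batched": {0: 0},
    "orgqr": {0: 0},          # Q overwrites the packed factor
    "syevd": {0: 0},          # eigenvectors overwrite a
    "gesvd": {0: 0},          # the kernel destroys a, so reuse its buffer
}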

@@ -356,7 +356,8 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
       [dl, d, du, B],
       backend_config=gpu_sparse.build_gtsv2_descriptor(m, n, ldb),
       operand_layouts=[[0]] * 3 + [[1, 0]],
-      result_layouts=[[1, 0], [0]])
+      result_layouts=[[1, 0], [0]],
+      operand_output_aliases={3: 0})
   return out[0]
 
 cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)
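The tridiagonal solver follows the same pattern: cuSPARSE's gtsv2 writes the solution over the dense right-hand side B, hence the {3: 0} alias (operand 3 is B; the three diagonals are only read). A minimal NumPy stand-in for that in-place contract (illustrative only, not the CUDA kernel):

import numpy as np

def gtsv2_like(dl, d, du, B):
    """Thomas algorithm: overwrites B with the solution, like gtsv2.

    dl/d/du are the sub-, main, and super-diagonals, each of length n
    (dl[0] and du[-1] unused, matching the cuSPARSE convention).
    """
    n = d.shape[0]
    cp = np.empty(n - 1)
    for j in range(B.shape[1]):
        cp[0] = du[0] / d[0]
        B[0, j] /= d[0]
        for i in range(1, n):  # forward elimination
            m = d[i] - dl[i] * cp[i - 1]
            if i < n - 1:
                cp[i] = du[i] / m
            B[i, j] = (B[i, j] - dl[i] * B[i - 1, j]) / m
        for i in range(n - 2, -1, -1):  # back substitution
            B[i, j] -= cp[i] * B[i + 1, j]
    return B  # B now holds X; no separate output buffer is needed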

@@ -14,28 +14,37 @@
 # Helpers for building MHLO operators
 
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union
 
 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.mhlo as mhlo
 import numpy as np
 
-def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
-                operands: Sequence[ir.Value],
-                operand_layouts: Sequence[Sequence[int]],
-                result_layouts: Sequence[Sequence[int]],
-                backend_config: Optional[str] = None,
-                has_side_effect: bool = False,
-                api_version: int = 2,
-                ) -> Union[ir.Value, Sequence[ir.Value]]:
+def custom_call(
+    call_target_name: str,
+    out_types: Sequence[ir.Type],
+    operands: Sequence[ir.Value],
+    operand_layouts: Sequence[Sequence[int]],
+    result_layouts: Sequence[Sequence[int]],
+    backend_config: Optional[str] = None,
+    has_side_effect: bool = False,
+    api_version: int = 2,
+    operand_output_aliases: Dict[int, int] = {},
+) -> Union[ir.Value, Sequence[ir.Value]]:
   """Less-verbose helper for building an MHLO custom call op.
 
   Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
   may be able to go away.
+
+  Args:
+    ...
+    operand_output_aliases: a dictionary mapping input numbers -> output numbers
+      that must alias.
   """
   i32_type = ir.IntegerType.get_signless(32)
   out = mhlo.CustomCallOp(
-      (out_types if len(out_types) == 1 else
-       [ir.TupleType.get_tuple(out_types)]),
+      (out_types
       if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
      operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
@@ -43,18 +52,27 @@ def custom_call(call_target_name: str, out_types: Sequence[ir.Type],
           "" if backend_config is None else backend_config),
       api_version=ir.IntegerAttr.get(i32_type, api_version),
       called_computations=ir.ArrayAttr.get([]),
-      operand_layouts=ir.ArrayAttr.get(
-          [ir.DenseIntElementsAttr.get(
+      operand_layouts=ir.ArrayAttr.get([
+          ir.DenseIntElementsAttr.get(
               np.atleast_1d(np.asarray(l, dtype=np.int64)),
-              type=ir.IndexType.get())
-           for l in operand_layouts]),
-      result_layouts=ir.ArrayAttr.get(
-          [ir.DenseIntElementsAttr.get(
+              type=ir.IndexType.get()) for l in operand_layouts
+      ]),
+      result_layouts=ir.ArrayAttr.get([
+          ir.DenseIntElementsAttr.get(
               np.atleast_1d(np.asarray(l, dtype=np.int64)),
-              type=ir.IndexType.get())
-           for l in result_layouts]))
+              type=ir.IndexType.get()) for l in result_layouts
+      ]),
+      output_operand_aliases=ir.ArrayAttr.get([
+          mhlo.OutputOperandAlias.get(
+              output_tuple_indices=[] if len(out_types) == 1 else [output],
+              operand_index=input,
+              operand_tuple_indices=[])
+          for input, output in operand_output_aliases.items()
+      ]))
   if len(out_types) == 1:
     return out.result
   else:
-    return [mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
-            for i in range(len(out_types))]
+    return [
+        mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
+        for i in range(len(out_types))
+    ]
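One subtlety in the helper worth calling out: with a single result there is no tuple, so output_tuple_indices is empty; with multiple results the custom call returns a tuple, and the alias must name the tuple element. A dependency-free sketch of that bookkeeping (plain dicts standing in for the mhlo.OutputOperandAlias attributes; alias_attributes is a hypothetical name):

def alias_attributes(operand_output_aliases, num_outputs):
    # Mirrors the list comprehension in custom_call above: one alias
    # record per {operand: output} entry.
    return [
        {
            "output_tuple_indices": [] if num_outputs == 1 else [output],
            "operand_index": operand,
            "operand_tuple_indices": [],
        }
        for operand, output in operand_output_aliases.items()
    ]

# trsm above has three results (solution plus two workspaces), so output 0
# is tuple element 0 and it aliases operand 1 (b):
assert alias_attributes({1: 0}, num_outputs=3) == [
    {"output_tuple_indices": [0], "operand_index": 1,
     "operand_tuple_indices": []}
]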