rocm_jax/jaxlib/lapack.py
Eugene Burmako a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00

791 lines
25 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
# via CustomCallWithLayout.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from jaxlib import xla_client
from .hlo_helpers import custom_call
from .cpu import _lapack
# Register each custom-call target reported by the compiled `_lapack`
# extension (name -> target) with XLA for the CPU platform.
for _name, _value in _lapack.registrations().items():
  xla_client.register_custom_call_target(_name, _value, platform="cpu")
# Function that lazily initializes the LAPACK kernels in the runtime on first
# use. Each *_hlo builder below calls it before emitting its custom call.
_initialize = _lapack.initialize
def _hlo_u8(x):
  """Returns the result of a new `hlo.ConstantOp` holding `x` as uint8.

  `x` is converted with `np.array(x, dtype=np.uint8)`, so a Python scalar
  yields a rank-0 (scalar) constant whose element type is the unsigned
  8-bit MLIR integer type.
  """
  payload = np.array(x, dtype=np.uint8)
  element_type = ir.IntegerType.get_unsigned(8)
  attr = ir.DenseElementsAttr.get(payload, type=element_type)
  return hlo.ConstantOp(attr).result
def _hlo_s32(x):
  """Returns the result of a new `hlo.ConstantOp` holding `x` as int32.

  `x` is converted with `np.array(x, dtype=np.int32)`, so a Python int yields
  a rank-0 (scalar) constant whose element type is the signless 32-bit MLIR
  integer type.
  """
  payload = np.array(x, dtype=np.int32)
  element_type = ir.IntegerType.get_signless(32)
  attr = ir.DenseElementsAttr.get(payload, type=element_type)
  return hlo.ConstantOp(attr).result
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
  """Legacy MHLO-named alias: forwards every argument to `trsm_hlo`."""
  flags = dict(left_side=left_side, lower=lower, trans_a=trans_a,
               conj_a=conj_a, diag=diag)
  return trsm_hlo(dtype, alpha, a, b, **flags)
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
             conj_a=False, diag=False):
  """Emits a batched BLAS ?trsm (triangular solve) as an XLA custom call.

  Args:
    dtype: NumPy element type; one of float32, float64, complex64 or
      complex128. Selects the `blas_?trsm` kernel.
    alpha: value passed to the kernel as a scalar operand (empty layout).
    a: value of shape `batch + (k, k)`, where `k` is `m` if `left_side` else
      `n`; must have the same element type as `b`.
    b: value of shape `batch + (m, n)`. It is aliased to the only result
      (operand 9 -> result 0), so the solution overwrites `b`.
    left_side, lower, diag: flags forwarded to the kernel as int32 scalars.
    trans_a, conj_a: folded into a single int32 scalar: 0 = no transpose,
      1 = transpose, 2 = conjugate transpose.

  Returns:
    Whatever `custom_call` returns for its single result of `b`'s type.

  Raises:
    ValueError: if `b` has rank < 2, or `a`'s shape / element type does not
      match `b`.
    NotImplementedError: for an unsupported `dtype`, or if `conj_a` is set
      without `trans_a`.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  b_type = ir.RankedTensorType(b.type)
  dims = b_type.shape
  # Check the rank up front so a rank < 2 `b` gets a descriptive ValueError
  # instead of a tuple-unpacking one.
  if len(dims) < 2:
    raise ValueError(f"trsm expects b to have rank >= 2, got {b_type}")
  m, n = dims[-2:]
  k = m if left_side else n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  num_b = 1
  for d in batch_dims:
    num_b *= d
  if (batch_dims + (k, k) != tuple(a_type.shape) or
      a_type.element_type != b_type.element_type):
    raise ValueError(f"Argument mismatch for trsm, got {a_type} and {b_type}")

  if dtype == np.float32:
    fn = "blas_strsm"
  elif dtype == np.float64:
    fn = "blas_dtrsm"
  elif dtype == np.complex64:
    fn = "blas_ctrsm"
  elif dtype == np.complex128:
    fn = "blas_ztrsm"
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")
  trans_code = (2 if conj_a else 1) if trans_a else 0

  scalar_layout = []
  # NOTE(review): appears to be a minor-to-major order (column-major matrix,
  # batch dims outermost) — confirm against `custom_call`.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  return custom_call(
      fn,
      [b.type],
      [_hlo_s32(int(left_side)), _hlo_s32(int(lower)),
       _hlo_s32(trans_code), _hlo_s32(int(diag)),
       _hlo_s32(m), _hlo_s32(n), _hlo_s32(num_b),
       alpha, a, b],
      operand_layouts=[scalar_layout] * 8 + [layout] * 2,
      result_layouts=[layout],
      operand_output_aliases={9: 0},
  )
# ?getrf: LU decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def getrf_mhlo(dtype, a):
  """Legacy MHLO-named alias of `getrf_hlo`."""
  return getrf_hlo(dtype=dtype, a=a)
def getrf_hlo(dtype, a):
  """Emits a batched LAPACK ?getrf (LU decomposition) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: value of rank >= 2 with shape `batch + (m, n)`. It is aliased to the
      first result (operand 3 -> result 0).

  Returns:
    The custom call's three results: a value of `a`'s type, an int32
    `batch + (min(m, n),)` array (presumably the pivot indices) and an int32
    `batch` array (presumably LAPACK's `info`).

  Raises:
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  shape = ir.RankedTensorType(a.type).shape
  assert len(shape) >= 2
  *leading, m, n = shape
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  for np_type, target in ((np.float32, b"lapack_sgetrf"),
                          (np.float64, b"lapack_dgetrf"),
                          (np.complex64, b"lapack_cgetrf"),
                          (np.complex128, b"lapack_zgetrf")):
    if dtype == np_type:
      break
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  # NOTE(review): these look like minor-to-major orders (column-major
  # matrices, batch dims outermost) — confirm against `custom_call`.
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  i32 = ir.IntegerType.get_signless(32)
  return custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (min(m, n),), i32),
          ir.RankedTensorType.get(batch_dims, i32),
      ],
      [_hlo_s32(int(batch_size)), _hlo_s32(m), _hlo_s32(n), a],
      operand_layouts=[[]] * 3 + [matrix_layout],
      result_layouts=[matrix_layout, (num_bd,) + batch_order, batch_order],
      operand_output_aliases={3: 0},
  )
# ?geqrf: QR decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geqrf_mhlo(dtype, a):
  """Legacy MHLO-named alias of `geqrf_hlo`."""
  return geqrf_hlo(dtype=dtype, a=a)
def geqrf_hlo(dtype, a):
  """Emits a batched LAPACK ?geqrf (QR decomposition) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: value of rank >= 2 with shape `batch + (m, n)`; aliased to the first
      result (operand 4 -> result 0).

  Returns:
    The first three custom-call results: a value of `a`'s type, a
    `batch + (min(m, n),)` array of `a`'s element type (presumably the
    Householder scalars `tau`) and an int32 `batch` array (presumably
    LAPACK's `info`). The trailing workspace result is dropped.

  Raises:
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  shape = a_type.shape
  assert len(shape) >= 2
  *leading, m, n = shape
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  if dtype == np.float32:
    target, ws_fn = b"lapack_sgeqrf", _lapack.lapack_sgeqrf_workspace
  elif dtype == np.float64:
    target, ws_fn = b"lapack_dgeqrf", _lapack.lapack_dgeqrf_workspace
  elif dtype == np.complex64:
    target, ws_fn = b"lapack_cgeqrf", _lapack.lapack_cgeqrf_workspace
  elif dtype == np.complex128:
    target, ws_fn = b"lapack_zgeqrf", _lapack.lapack_zgeqrf_workspace
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  lwork = ws_fn(m, n)

  elem = a_type.element_type
  i32 = ir.IntegerType.get_signless(32)
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  results = custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (min(m, n),), elem),
          ir.RankedTensorType.get(batch_dims, i32),
          ir.RankedTensorType.get([lwork], elem),
      ],
      [_hlo_s32(int(batch_size)), _hlo_s32(m), _hlo_s32(n),
       _hlo_s32(lwork), a],
      operand_layouts=[[]] * 4 + [matrix_layout],
      result_layouts=[
          matrix_layout,
          (num_bd,) + batch_order,
          batch_order,
          [0],
      ],
      operand_output_aliases={4: 0},
  )
  return results[:3]
# ?orgqr: product of elementary Householder reflectors:
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def orgqr_mhlo(dtype, a, tau):
  """Legacy MHLO-named alias of `orgqr_hlo`."""
  return orgqr_hlo(dtype=dtype, a=a, tau=tau)
def orgqr_hlo(dtype, a, tau):
  """Emits a batched LAPACK ?orgqr / ?ungqr custom call.

  Args:
    dtype: NumPy element type of `a`; float32/float64 use ?orgqr and
      complex64/complex128 use ?ungqr.
    a: value of rank >= 2 with shape `batch + (m, n)`; aliased to the first
      result (operand 5 -> result 0).
    tau: value of shape `batch + (k,)`; its batch dims must equal `a`'s.

  Returns:
    The first two custom-call results: a value of `a`'s type and an int32
    `batch` array (presumably LAPACK's `info`). The workspace result is
    dropped.

  Raises:
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  shape = a_type.shape
  assert len(shape) >= 2
  *leading, m, n = shape
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent
  tau_shape = ir.RankedTensorType(tau.type).shape
  assert tau_shape[:-1] == shape[:-2], (tau.type, a.type)
  k = tau_shape[-1]

  if dtype == np.float32:
    target, ws_fn = b"lapack_sorgqr", _lapack.lapack_sorgqr_workspace
  elif dtype == np.float64:
    target, ws_fn = b"lapack_dorgqr", _lapack.lapack_dorgqr_workspace
  elif dtype == np.complex64:
    target, ws_fn = b"lapack_cungqr", _lapack.lapack_cungqr_workspace
  elif dtype == np.complex128:
    target, ws_fn = b"lapack_zungqr", _lapack.lapack_zungqr_workspace
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  lwork = ws_fn(m, n, k)

  i32 = ir.IntegerType.get_signless(32)
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  results = custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims, i32),
          ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      [_hlo_s32(int(batch_size)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(k),
       _hlo_s32(lwork), a, tau],
      operand_layouts=[[]] * 5 + [matrix_layout, (num_bd,) + batch_order],
      result_layouts=[matrix_layout, batch_order, [0]],
      operand_output_aliases={5: 0},
  )
  return results[:2]
# ?potrf: Cholesky decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def potrf_mhlo(dtype, a, lower=False):
  """Legacy MHLO-named alias of `potrf_hlo`."""
  return potrf_hlo(dtype, a, lower)
def potrf_hlo(dtype, a, lower=False):
  """Emits a batched LAPACK ?potrf (Cholesky decomposition) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`; aliased to the first result
      (operand 3 -> result 0).
    lower: forwarded to the kernel as an int32 flag (presumably selects the
      lower rather than the upper triangle).

  Returns:
    The custom call's two results: a value of `a`'s type and an int32
    `batch` array (`info`).

  Raises:
    ValueError: if `a` has rank < 2 or its trailing two dimensions differ.
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  # Other builders in this module check the rank; do so here too, so a
  # rank < 2 input gets a descriptive ValueError rather than an unpack error.
  if len(dims) < 2:
    raise ValueError(f"potrf expects an array of rank >= 2, got {a_type}")
  m, n = dims[-2:]
  if m != n:
    raise ValueError(f"potrf expects a square matrix, got {a_type}")
  if dtype == np.float32:
    fn = b"lapack_spotrf"
  elif dtype == np.float64:
    fn = b"lapack_dpotrf"
  elif dtype == np.complex64:
    fn = b"lapack_cpotrf"
  elif dtype == np.complex128:
    fn = b"lapack_zpotrf"
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  scalar_layout = []
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  info_layout = tuple(range(num_bd - 1, -1, -1))
  out = custom_call(
      fn,
      [a.type,
       ir.RankedTensorType.get(batch_dims, ir.IntegerType.get_signless(32))],
      [_hlo_s32(int(lower)), _hlo_s32(b), _hlo_s32(n), a],
      operand_layouts=[scalar_layout] * 3 + [layout],
      result_layouts=[layout, info_layout],
      operand_output_aliases={3: 0},
  )
  return out[:2]
# ?gesdd: Singular value decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
  """Legacy MHLO-named alias of `gesdd_hlo`."""
  return gesdd_hlo(dtype, a, full_matrices, compute_uv)
def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
  """Emits a batched LAPACK ?gesdd (singular value decomposition) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: value of rank >= 2 with shape `batch + (m, n)`; aliased to the first
      (discarded) result (operand 6 -> result 0).
    full_matrices: if true, the `u` result is `m x m` and `vt` is `n x n`;
      otherwise the reduced dimension is `min(m, n)`.
    compute_uv: forwarded to the kernel and to the workspace-size queries.

  Returns:
    Custom-call results 1..4:
    - a real `batch + (min(m, n),)` array (singular values);
    - `u`, of shape `batch + (m, m or min(m, n))`;
    - `vt`, of shape `batch + (n or min(m, n), n)`;
    - an int32 `batch` array (`info`).

  Raises:
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  shape = a_type.shape
  assert len(shape) >= 2
  *leading, m, n = shape
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent
  i32 = ir.IntegerType.get_signless(32)
  elem = a_type.element_type

  if dtype == np.float32:
    target, real_cls, work_size_fn, is_complex = (
        b"lapack_sgesdd", ir.F32Type, _lapack.sgesdd_work_size, False)
  elif dtype == np.float64:
    target, real_cls, work_size_fn, is_complex = (
        b"lapack_dgesdd", ir.F64Type, _lapack.dgesdd_work_size, False)
  elif dtype == np.complex64:
    target, real_cls, work_size_fn, is_complex = (
        b"lapack_cgesdd", ir.F32Type, _lapack.cgesdd_work_size, True)
  elif dtype == np.complex128:
    target, real_cls, work_size_fn, is_complex = (
        b"lapack_zgesdd", ir.F64Type, _lapack.zgesdd_work_size, True)
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  singular_vals_type = real_cls.get()
  lwork = work_size_fn(m, n, compute_uv, full_matrices)

  workspace = [
      ir.RankedTensorType.get([_lapack.gesdd_iwork_size(m, n)], i32)]
  if is_complex:
    # The complex kernels take an extra real-valued workspace; its size comes
    # from `cgesdd_rwork_size` for both complex precisions.
    workspace.append(ir.RankedTensorType.get(
        [_lapack.cgesdd_rwork_size(m, n, int(compute_uv))],
        singular_vals_type))
  workspace.append(ir.RankedTensorType.get([lwork], elem))
  workspace_layouts = [[0] for _ in workspace]

  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  reduced = min(m, n)
  u_cols = m if full_matrices else reduced
  vt_rows = n if full_matrices else reduced
  results = custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (reduced,), singular_vals_type),
          ir.RankedTensorType.get(batch_dims + (m, u_cols), elem),
          ir.RankedTensorType.get(batch_dims + (vt_rows, n), elem),
          ir.RankedTensorType.get(batch_dims, i32),
      ] + workspace,
      [_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)),
       _hlo_s32(batch_size), _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
      operand_layouts=[[]] * 6 + [matrix_layout],
      result_layouts=[
          matrix_layout,
          (num_bd,) + batch_order,
          matrix_layout,
          matrix_layout,
          batch_order,
      ] + workspace_layouts,
      operand_output_aliases={6: 0},
  )
  return results[1:5]
# syevd: Symmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def syevd_mhlo(dtype, a, lower=False):
  """Legacy MHLO-named alias of `syevd_hlo`."""
  return syevd_hlo(dtype, a, lower)
def syevd_hlo(dtype, a, lower=False):
  """Emits a batched LAPACK ?syevd / ?heevd eigendecomposition custom call.

  Real dtypes use ?syevd (symmetric); complex dtypes use ?heevd (Hermitian).

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`; aliased to the first result
      (operand 3 -> result 0).
    lower: forwarded to the kernel as an int32 flag (1 if true, else 0).

  Returns:
    The first three custom-call results: a value of `a`'s type, a real
    `batch + (n,)` array (eigenvalues) and an int32 `batch` array (`info`).
    Workspace results are dropped.

  Raises:
    AssertionError: if `a` has rank < 2 or is not square.
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  i32_type = ir.IntegerType.get_signless(32)

  # Eigenvalues are real for every dtype; only the complex (?heevd) kernels
  # need the extra real-valued `rwork` workspace.
  if dtype == np.float32:
    fn, eigvals_type, hermitian = b"lapack_ssyevd", ir.F32Type.get(), False
  elif dtype == np.float64:
    fn, eigvals_type, hermitian = b"lapack_dsyevd", ir.F64Type.get(), False
  elif dtype == np.complex64:
    fn, eigvals_type, hermitian = b"lapack_cheevd", ir.F32Type.get(), True
  elif dtype == np.complex128:
    fn, eigvals_type, hermitian = b"lapack_zheevd", ir.F64Type.get(), True
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  if hermitian:
    workspace = [
        ir.RankedTensorType.get([_lapack.heevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.heevd_rwork_size(n)], eigvals_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
  else:
    workspace = [
        ir.RankedTensorType.get([_lapack.syevd_work_size(n)],
                                a_type.element_type),
        ir.RankedTensorType.get([_lapack.syevd_iwork_size(n)], i32_type),
    ]
  # Every workspace buffer is rank 1.
  workspace_layouts = [[0] for _ in workspace]

  scalar_layout = []
  # Computed once (it was previously assigned twice with the same value).
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  out = custom_call(
      fn,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (n,), eigvals_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
      ] + workspace,
      [_hlo_s32(1 if lower else 0), _hlo_s32(b), _hlo_s32(n), a],
      operand_layouts=[scalar_layout] * 3 + [layout],
      result_layouts=[
          layout,
          tuple(range(num_bd, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
      ] + workspace_layouts,
      operand_output_aliases={3: 0},
  )
  return out[:3]
# geev: Nonsymmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
  """Legacy MHLO-named alias of `geev_hlo`."""
  return geev_hlo(dtype, a, jobvl, jobvr)
def geev_hlo(dtype, a, jobvl=True, jobvr=True):
  """Emits a batched LAPACK ?geev (nonsymmetric eigendecomposition) call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`.
    jobvl: passed to the kernel as the uint8 character 'V' if true else 'N'
      (left eigenvectors).
    jobvr: same encoding, for the right eigenvectors.

  Returns:
    A 4-sequence `(w, vl, vr, info)`:
    - `w`: complex eigenvalues. For real dtypes the kernel emits two real
      arrays, which are combined here with `hlo.ComplexOp`.
    - `vl`, `vr`: complex arrays shaped like `a`.
    - `info`: an int32 `batch` array.

  Raises:
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  shape = ir.RankedTensorType(a.type).shape
  assert len(shape) >= 2
  *leading, m, n = shape
  assert m == n
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  vector_layout = (num_bd,) + batch_order
  jobvl_c = ord('V' if jobvl else 'N')
  jobvr_c = ord('V' if jobvr else 'N')

  kernels = (
      (np.float32, b"lapack_sgeev", ir.F32Type, True),
      (np.float64, b"lapack_dgeev", ir.F64Type, True),
      (np.complex64, b"lapack_cgeev", ir.F32Type, False),
      (np.complex128, b"lapack_zgeev", ir.F64Type, False),
  )
  for np_type, target, real_cls, real in kernels:
    if dtype == np_type:
      break
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  real_type = real_cls.get()
  eigvecs_type = ir.ComplexType.get(real_type)
  if real:
    workspaces = [ir.RankedTensorType.get([n, n], real_type)
                  for _ in range(3)]
    workspace_layouts = [[0, 1]] * 3
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), real_type)
               for _ in range(2)]
    eigvals_layouts = [vector_layout] * 2
  else:
    workspaces = [ir.RankedTensorType.get([n, n], eigvecs_type),
                  ir.RankedTensorType.get([2 * n], real_type)]
    workspace_layouts = [[0, 1], [0]]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), eigvecs_type)]
    eigvals_layouts = [vector_layout]

  i32 = ir.IntegerType.get_signless(32)
  results = custom_call(
      target,
      workspaces + eigvals + [
          ir.RankedTensorType.get(shape, eigvecs_type),
          ir.RankedTensorType.get(shape, eigvecs_type),
          ir.RankedTensorType.get(batch_dims, i32),
      ],
      [_hlo_s32(batch_size), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c),
       a],
      operand_layouts=[[]] * 4 + [matrix_layout],
      result_layouts=(workspace_layouts + eigvals_layouts +
                      [matrix_layout] * 2 + [batch_order]),
  )
  if not real:
    # Results: 2 workspaces, w, vl, vr, info.
    return results[2:6]
  # Results: 3 workspaces, two real eigenvalue arrays, vl, vr, info.
  w = hlo.ComplexOp(results[3], results[4]).result
  return (w, results[5], results[6], results[7])
# gees: Schur factorization
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
  """Legacy MHLO-named alias of `gees_hlo`."""
  return gees_hlo(dtype, a, jobvs, sort, select)
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
  """Emits a batched LAPACK ?gees (Schur factorization) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`; aliased to the first result
      (operand 4 -> result 0).
    jobvs: passed to the kernel as the uint8 character 'V' if true else 'N'.
    sort: eigenvalue sorting is not implemented; a truthy value raises.
    select: currently unused (see the TODO at the operand list).

  Returns:
    `(out[0], out[3], out[5])`:
    - a value of `a`'s type (presumably the Schur form T);
    - a value shaped like `a` (presumably the Schur vectors);
    - an int32 `batch` array (`info`).
    If sorting were enabled, the int32 `batch` result at index 4 (presumably
    LAPACK's `sdim`) would also be returned before `info`.

  Raises:
    AssertionError: if `a` has rank < 2 or is not square.
    NotImplementedError: if `sort` is truthy or `dtype` is unsupported.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  etype = a_type.element_type
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  if sort:
    raise NotImplementedError(
        "The sort feature of LAPACK's gees routine is not implemented.")
  # Character codes for the kernel, kept separate from the boolean
  # parameters (same `_c` convention as `geev_hlo`).
  jobvs_c = ord('V' if jobvs else 'N')
  sort_c = ord('S' if sort else 'N')

  if dtype == np.float32:
    fn = "lapack_sgees"
  elif dtype == np.float64:
    fn = "lapack_dgees"
  elif dtype == np.complex64:
    fn = "lapack_cgees"
  elif dtype == np.complex128:
    fn = "lapack_zgees"
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  if not np.issubdtype(dtype, np.complexfloating):
    # Real kernels: one matrix workspace, two real eigenvalue arrays.
    workspaces = [ir.RankedTensorType.get(dims, etype)]
    workspace_layouts = [layout]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
    eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
  else:
    # Complex kernels: matrix workspace plus a real-valued vector workspace,
    # and a single complex eigenvalue array.
    workspaces = [
        ir.RankedTensorType.get(dims, etype),
        ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
    ]
    workspace_layouts = [layout, [0]]
    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
    eigvals_layouts = [tuple(range(num_bd, -1, -1))]

  i32_type = ir.IntegerType.get_signless(32)
  scalar_layout = []
  out = custom_call(
      fn,
      workspaces + eigvals + [
          ir.RankedTensorType.get(dims, etype),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
      ],
      [
          _hlo_s32(b),
          _hlo_s32(n),
          _hlo_u8(np.uint8(jobvs_c)),
          _hlo_u8(np.uint8(sort_c)),
          # TODO: figure out how to put the callable select function here
          a
      ],
      operand_layouts=[scalar_layout] * 4 + [layout],
      result_layouts=workspace_layouts + eigvals_layouts + [
          layout,
          tuple(range(num_bd - 1, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
      ],
      operand_output_aliases={4: 0},
  )
  # Because sorting raises above, only the `else` branch is currently
  # reachable. The sort branch is kept for when sorting is supported.
  if sort_c == ord('S'):
    return (out[0], out[3], out[4], out[5])
  else:
    return (out[0], out[3], out[5])
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gehrd_mhlo(dtype, a):
  """Legacy MHLO-named alias of `gehrd_hlo`."""
  return gehrd_hlo(dtype=dtype, a=a)
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_hlo(dtype, a):
  """Emits a batched LAPACK ?gehrd (upper Hessenberg reduction) custom call.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`; aliased to the first result
      (operand 6 -> result 0).

  Returns:
    The first three custom-call results:
    - a value of `a`'s type;
    - a `batch + (n - 1,)` array of `a`'s element type (presumably `tau`);
    - an int32 `batch` array (`info`).
    The workspace result is dropped.

  Raises:
    AssertionError: if `a` has rank < 2 or is not square.
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  shape = a_type.shape
  assert len(shape) >= 2
  *leading, m, n = shape
  assert m == n, (m, n)
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  if dtype == np.float32:
    target, ws_fn = b"lapack_sgehrd", _lapack.lapack_sgehrd_workspace
  elif dtype == np.float64:
    target, ws_fn = b"lapack_dgehrd", _lapack.lapack_dgehrd_workspace
  elif dtype == np.complex64:
    target, ws_fn = b"lapack_cgehrd", _lapack.lapack_cgehrd_workspace
  elif dtype == np.complex128:
    target, ws_fn = b"lapack_zgehrd", _lapack.lapack_zgehrd_workspace
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  lwork = ws_fn(n, n, 1, n)

  elem = a_type.element_type
  i32 = ir.IntegerType.get_signless(32)
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  results = custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (n - 1,), elem),
          ir.RankedTensorType.get(batch_dims, i32),
          ir.RankedTensorType.get([lwork], elem),
      ],
      # NOTE(review): the leading (n, 1, n, n) presumably map to LAPACK's
      # n, ilo, ihi, lda — confirm against the kernel.
      [_hlo_s32(n), _hlo_s32(1), _hlo_s32(n), _hlo_s32(n),
       _hlo_s32(batch_size), _hlo_s32(lwork), a],
      operand_layouts=[[]] * 6 + [matrix_layout],
      result_layouts=[
          matrix_layout,
          (num_bd,) + batch_order,
          batch_order,
          [0],
      ],
      operand_output_aliases={6: 0},
  )
  return results[:3]
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def sytrd_mhlo(dtype, a, *, lower):
  """Legacy MHLO-named alias of `sytrd_hlo`."""
  return sytrd_hlo(dtype, a, lower=lower)
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_hlo(dtype, a, *, lower):
  """Emits a batched LAPACK ?sytrd / ?hetrd (tridiagonal reduction) call.

  Real dtypes use ?sytrd; complex dtypes use ?hetrd.

  Args:
    dtype: NumPy element type of `a`; one of float32, float64, complex64 or
      complex128.
    a: square value of shape `batch + (n, n)`; aliased to the first result
      (operand 5 -> result 0).
    lower: forwarded to the kernel as an int32 flag (1 if true, else 0).

  Returns:
    The first five custom-call results:
    - a value of `a`'s type;
    - a `batch + (n,)` array (presumably the diagonal);
    - a `batch + (n - 1,)` array (presumably the off-diagonal); both of
      these have a real element type for complex inputs;
    - a `batch + (n - 1,)` array of `a`'s element type (presumably `tau`);
    - an int32 `batch` array (`info`).
    The workspace result is dropped.

  Raises:
    AssertionError: if `a` has rank < 2 or is not square.
    NotImplementedError: if `dtype` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  shape = a_type.shape
  assert len(shape) >= 2
  *leading, m, n = shape
  assert m == n, (m, n)
  batch_dims = tuple(leading)
  num_bd = len(batch_dims)
  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  # `real_cls` is None when the diagonal shares `a`'s element type.
  if dtype == np.float32:
    target, ws_fn, real_cls = (
        b"lapack_ssytrd", _lapack.lapack_ssytrd_workspace, None)
  elif dtype == np.float64:
    target, ws_fn, real_cls = (
        b"lapack_dsytrd", _lapack.lapack_dsytrd_workspace, None)
  elif dtype == np.complex64:
    target, ws_fn, real_cls = (
        b"lapack_chetrd", _lapack.lapack_chetrd_workspace, ir.F32Type)
  elif dtype == np.complex128:
    target, ws_fn, real_cls = (
        b"lapack_zhetrd", _lapack.lapack_zhetrd_workspace, ir.F64Type)
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  lwork = ws_fn(n, n)
  elem = a_type.element_type
  diag_type = elem if real_cls is None else real_cls.get()

  i32 = ir.IntegerType.get_signless(32)
  batch_order = tuple(reversed(range(num_bd)))
  matrix_layout = (num_bd, num_bd + 1) + batch_order
  vector_layout = (num_bd,) + batch_order
  results = custom_call(
      target,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (n,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), elem),
          ir.RankedTensorType.get(batch_dims, i32),
          ir.RankedTensorType.get([lwork], elem),
      ],
      # NOTE(review): max(1, n) is presumably LAPACK's lda — confirm.
      [_hlo_s32(n), _hlo_s32(1 if lower else 0), _hlo_s32(max(1, n)),
       _hlo_s32(batch_size), _hlo_s32(lwork), a],
      operand_layouts=[[]] * 5 + [matrix_layout],
      result_layouts=[
          matrix_layout,
          vector_layout,
          vector_layout,
          vector_layout,
          batch_order,
          [0],
      ],
      operand_output_aliases={5: 0},
  )
  return results[:5]