Remove *_mhlo compatibility shims from jaxlib

We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.

PiperOrigin-RevId: 510820007
This commit is contained in:
Eugene Burmako 2023-02-19 09:02:36 -08:00 committed by jax authors
parent 418c2f9d2a
commit f337c00ed5
2 changed files with 0 additions and 50 deletions

View File

@ -107,10 +107,6 @@ def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
return descriptor, out_dtype, out_shape
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
return ducc_fft_hlo(a, dtype, fft_type=fft_type, fft_lengths=fft_lengths)
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)

View File

@ -46,12 +46,6 @@ def _hlo_s32(x):
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
return trsm_hlo(dtype, alpha, a, b, left_side=left_side, lower=lower,
trans_a=trans_a, conj_a=conj_a, diag=diag)
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
@ -105,10 +99,6 @@ def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
# # ?getrf: LU decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def getrf_mhlo(dtype, a):
return getrf_hlo(dtype, a)
def getrf_hlo(dtype, a):
_initialize()
dims = ir.RankedTensorType(a.type).shape
@ -154,10 +144,6 @@ def getrf_hlo(dtype, a):
# # ?geqrf: QR decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geqrf_mhlo(dtype, a):
return geqrf_hlo(dtype, a)
def geqrf_hlo(dtype, a):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -211,10 +197,6 @@ def geqrf_hlo(dtype, a):
# # ?orgqr: product of elementary Householder reflectors:
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def orgqr_mhlo(dtype, a, tau):
return orgqr_hlo(dtype, a, tau)
def orgqr_hlo(dtype, a, tau):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -274,10 +256,6 @@ def orgqr_hlo(dtype, a, tau):
# ?potrf: Cholesky decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def potrf_mhlo(dtype, a, lower=False):
return potrf_hlo(dtype, a, lower=lower)
def potrf_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -319,10 +297,6 @@ def potrf_hlo(dtype, a, lower=False):
# # ?gesdd: Singular value decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
return gesdd_hlo(dtype, a, full_matrices=full_matrices, compute_uv=compute_uv)
def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -413,10 +387,6 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
# # syevd: Symmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def syevd_mhlo(dtype, a, lower=False):
return syevd_hlo(dtype, a, lower=lower)
def syevd_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -496,10 +466,6 @@ def syevd_hlo(dtype, a, lower=False):
# # geev: Nonsymmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
return geev_hlo(dtype, a, jobvl=jobvl, jobvr=jobvr)
def geev_hlo(dtype, a, jobvl=True, jobvr=True):
_initialize()
dims = ir.RankedTensorType(a.type).shape
@ -585,10 +551,6 @@ def geev_hlo(dtype, a, jobvl=True, jobvr=True):
# # gees : Schur factorization
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return gees_hlo(dtype, a, jobvs=jobvs, sort=sort, select=select)
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -668,10 +630,6 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
return (out[0], out[3], out[5])
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gehrd_mhlo(dtype, a):
return gehrd_hlo(dtype, a)
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_hlo(dtype, a):
_initialize()
@ -725,10 +683,6 @@ def gehrd_hlo(dtype, a):
return out[:3]
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def sytrd_mhlo(dtype, a, *, lower):
return sytrd_hlo(dtype, a, lower=lower)
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_hlo(dtype, a, *, lower):
_initialize()