mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Remove *_mhlo compatibility shims from jaxlib
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted. PiperOrigin-RevId: 510820007
This commit is contained in:
parent
418c2f9d2a
commit
f337c00ed5
@ -107,10 +107,6 @@ def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
return descriptor, out_dtype, out_shape
|
||||
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
return ducc_fft_hlo(a, dtype, fft_type=fft_type, fft_lengths=fft_lengths)
|
||||
|
||||
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
"""DUCC FFT kernel for CPU."""
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
|
@ -46,12 +46,6 @@ def _hlo_s32(x):
|
||||
|
||||
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
conj_a=False, diag=False):
|
||||
return trsm_hlo(dtype, alpha, a, b, left_side=left_side, lower=lower,
|
||||
trans_a=trans_a, conj_a=conj_a, diag=diag)
|
||||
|
||||
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
|
||||
# triangular solve
|
||||
def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
@ -105,10 +99,6 @@ def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
|
||||
# # ?getrf: LU decomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def getrf_mhlo(dtype, a):
|
||||
return getrf_hlo(dtype, a)
|
||||
|
||||
def getrf_hlo(dtype, a):
|
||||
_initialize()
|
||||
dims = ir.RankedTensorType(a.type).shape
|
||||
@ -154,10 +144,6 @@ def getrf_hlo(dtype, a):
|
||||
|
||||
# # ?geqrf: QR decomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def geqrf_mhlo(dtype, a):
|
||||
return geqrf_hlo(dtype, a)
|
||||
|
||||
def geqrf_hlo(dtype, a):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -211,10 +197,6 @@ def geqrf_hlo(dtype, a):
|
||||
|
||||
# # ?orgqr: product of elementary Householder reflectors:
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def orgqr_mhlo(dtype, a, tau):
|
||||
return orgqr_hlo(dtype, a, tau)
|
||||
|
||||
def orgqr_hlo(dtype, a, tau):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -274,10 +256,6 @@ def orgqr_hlo(dtype, a, tau):
|
||||
|
||||
# ?potrf: Cholesky decomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def potrf_mhlo(dtype, a, lower=False):
|
||||
return potrf_hlo(dtype, a, lower=lower)
|
||||
|
||||
def potrf_hlo(dtype, a, lower=False):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -319,10 +297,6 @@ def potrf_hlo(dtype, a, lower=False):
|
||||
|
||||
# # ?gesdd: Singular value decomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
|
||||
return gesdd_hlo(dtype, a, full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
|
||||
def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -413,10 +387,6 @@ def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
|
||||
|
||||
# # syevd: Symmetric eigendecomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def syevd_mhlo(dtype, a, lower=False):
|
||||
return syevd_hlo(dtype, a, lower=lower)
|
||||
|
||||
def syevd_hlo(dtype, a, lower=False):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -496,10 +466,6 @@ def syevd_hlo(dtype, a, lower=False):
|
||||
|
||||
# # geev: Nonsymmetric eigendecomposition
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
|
||||
return geev_hlo(dtype, a, jobvl=jobvl, jobvr=jobvr)
|
||||
|
||||
def geev_hlo(dtype, a, jobvl=True, jobvr=True):
|
||||
_initialize()
|
||||
dims = ir.RankedTensorType(a.type).shape
|
||||
@ -585,10 +551,6 @@ def geev_hlo(dtype, a, jobvl=True, jobvr=True):
|
||||
|
||||
# # gees : Schur factorization
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
return gees_hlo(dtype, a, jobvs=jobvs, sort=sort, select=select)
|
||||
|
||||
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
@ -668,10 +630,6 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
return (out[0], out[3], out[5])
|
||||
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def gehrd_mhlo(dtype, a):
|
||||
return gehrd_hlo(dtype, a)
|
||||
|
||||
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
|
||||
def gehrd_hlo(dtype, a):
|
||||
_initialize()
|
||||
@ -725,10 +683,6 @@ def gehrd_hlo(dtype, a):
|
||||
return out[:3]
|
||||
|
||||
|
||||
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
|
||||
def sytrd_mhlo(dtype, a, *, lower):
|
||||
return sytrd_hlo(dtype, a, lower=lower)
|
||||
|
||||
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
|
||||
def sytrd_hlo(dtype, a, *, lower):
|
||||
_initialize()
|
||||
|
Loading…
x
Reference in New Issue
Block a user