mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8260 from hawkinsp:unpack
PiperOrigin-RevId: 404001515
This commit is contained in:
commit
9eb06800fe
@ -37,23 +37,10 @@ except ImportError:
|
||||
_ops = xla_client.ops
|
||||
_Shape = xla_client.Shape
|
||||
|
||||
# TODO(phawkins): remove after we no longer need to support old jax releases.
|
||||
def _unpack_builder(c):
|
||||
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
|
||||
return getattr(c, "_builder", c)
|
||||
|
||||
def _real_type(dtype):
|
||||
"""Returns the real equivalent of 'dtype'."""
|
||||
if dtype == np.float32:
|
||||
return np.float32
|
||||
elif dtype == np.float64:
|
||||
return np.float64
|
||||
elif dtype == np.complex64:
|
||||
return np.float32
|
||||
elif dtype == np.complex128:
|
||||
return np.float64
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
return np.finfo(dtype).dtype
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
||||
@ -63,7 +50,6 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
|
||||
|
||||
XLA implements unbatched triangular solve directly, so we need only implement
|
||||
the batched case."""
|
||||
c = _unpack_builder(c)
|
||||
b_shape = c.get_shape(b)
|
||||
dtype = b_shape.element_type()
|
||||
dims = b_shape.dimensions()
|
||||
@ -105,7 +91,6 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
|
||||
|
||||
def potrf(c, a, lower):
|
||||
"""Cholesky decomposition."""
|
||||
c = _unpack_builder(c)
|
||||
a_shape = c.get_shape(a)
|
||||
dtype = a_shape.element_type()
|
||||
dims = a_shape.dimensions()
|
||||
@ -141,7 +126,6 @@ def potrf(c, a, lower):
|
||||
|
||||
def getrf(c, a):
|
||||
"""LU decomposition."""
|
||||
c = _unpack_builder(c)
|
||||
a_shape = c.get_shape(a)
|
||||
dtype = a_shape.element_type()
|
||||
dims = a_shape.dimensions()
|
||||
@ -187,7 +171,6 @@ def getrf(c, a):
|
||||
|
||||
def geqrf(c, a):
|
||||
"""QR decomposition."""
|
||||
c = _unpack_builder(c)
|
||||
a_shape = c.get_shape(a)
|
||||
dtype = a_shape.element_type()
|
||||
dims = a_shape.dimensions()
|
||||
@ -227,7 +210,6 @@ def geqrf(c, a):
|
||||
|
||||
def orgqr(c, a, tau):
|
||||
"""Product of elementary Householder reflections."""
|
||||
c = _unpack_builder(c)
|
||||
a_shape = c.get_shape(a)
|
||||
dtype = a_shape.element_type()
|
||||
dims = a_shape.dimensions()
|
||||
@ -273,8 +255,6 @@ def orgqr(c, a, tau):
|
||||
|
||||
def syevd(c, a, lower=False):
|
||||
"""Symmetric (Hermitian) eigendecomposition."""
|
||||
c = _unpack_builder(c)
|
||||
|
||||
a_shape = c.get_shape(a)
|
||||
dtype = a_shape.element_type()
|
||||
dims = a_shape.dimensions()
|
||||
@ -321,8 +301,6 @@ def syevd(c, a, lower=False):
|
||||
|
||||
def gesvd(c, a, full_matrices=True, compute_uv=True):
|
||||
"""Singular value decomposition."""
|
||||
c = _unpack_builder(c)
|
||||
|
||||
a_shape = c.get_shape(a)
|
||||
dims = a_shape.dimensions()
|
||||
dtype = a_shape.element_type()
|
||||
|
Loading…
x
Reference in New Issue
Block a user