Merge pull request #8260 from hawkinsp:unpack

PiperOrigin-RevId: 404001515
This commit is contained in:
jax authors 2021-10-18 10:50:44 -07:00
commit 9eb06800fe

View File

@ -37,23 +37,10 @@ except ImportError:
_ops = xla_client.ops
_Shape = xla_client.Shape
# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
return getattr(c, "_builder", c)
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
if dtype == np.float32:
return np.float32
elif dtype == np.float64:
return np.float64
elif dtype == np.complex64:
return np.float32
elif dtype == np.complex128:
return np.float64
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
return np.finfo(dtype).dtype
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
@ -63,7 +50,6 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
XLA implements unbatched triangular solve directly, so we need only implement
the batched case."""
c = _unpack_builder(c)
b_shape = c.get_shape(b)
dtype = b_shape.element_type()
dims = b_shape.dimensions()
@ -105,7 +91,6 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
def potrf(c, a, lower):
"""Cholesky decomposition."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
@ -141,7 +126,6 @@ def potrf(c, a, lower):
def getrf(c, a):
"""LU decomposition."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
@ -187,7 +171,6 @@ def getrf(c, a):
def geqrf(c, a):
"""QR decomposition."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
@ -227,7 +210,6 @@ def geqrf(c, a):
def orgqr(c, a, tau):
"""Product of elementary Householder reflections."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
@ -273,8 +255,6 @@ def orgqr(c, a, tau):
def syevd(c, a, lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
@ -321,8 +301,6 @@ def syevd(c, a, lower=False):
def gesvd(c, a, full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
c = _unpack_builder(c)
a_shape = c.get_shape(a)
dims = a_shape.dimensions()
dtype = a_shape.element_type()