mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Implement jax.scipy.linalg.block_diag. (#2113)
This commit is contained in:
parent
0904e5ff74
commit
cfef568dd6
@ -9,6 +9,7 @@ jax.scipy.linalg
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
block_diag
|
||||
cho_factor
|
||||
cho_solve
|
||||
cholesky
|
||||
|
@ -278,7 +278,7 @@ def _expm(A, upper_triangular=False):
|
||||
R = _solve_P_Q(P, Q, upper_triangular)
|
||||
R = _squaring(R, n_squarings)
|
||||
return R
|
||||
|
||||
|
||||
@jit
|
||||
def _calc_P_Q(A):
|
||||
A = np.asarray(A)
|
||||
@ -379,3 +379,25 @@ def _pade13(A):
|
||||
U = np.dot(A,np.dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
|
||||
V = np.dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
|
||||
return U,V
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.block_diag)
|
||||
@jit
|
||||
def block_diag(*arrs):
|
||||
if len(arrs) == 0:
|
||||
arrs = [np.zeros((1, 0))]
|
||||
arrs = np._promote_dtypes(*arrs)
|
||||
bad_shapes = [i for i, a in enumerate(arrs) if np.ndim(a) > 2]
|
||||
if bad_shapes:
|
||||
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
|
||||
"most 2 dimensions, got {} at argument {}."
|
||||
.format(arrs[bad_shapes[0]], bad_shapes[0]))
|
||||
arrs = [np.atleast_2d(a) for a in arrs]
|
||||
acc = arrs[0]
|
||||
dtype = lax.dtype(acc)
|
||||
for a in arrs[1:]:
|
||||
_, c = a.shape
|
||||
a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
|
||||
acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
|
||||
acc = lax.concatenate([acc, a], dimension=0)
|
||||
return acc
|
||||
|
@ -682,6 +682,23 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_i={}".format(i), "args": args}
|
||||
for i, args in enumerate([
|
||||
(),
|
||||
(1,),
|
||||
(7, -2),
|
||||
(3, 4, 5),
|
||||
(onp.ones((3, 4), dtype=np.float_), 5,
|
||||
onp.random.randn(5, 2).astype(np.float_)),
|
||||
])))
|
||||
def testBlockDiag(self, args):
|
||||
args_maker = lambda: args
|
||||
self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
|
||||
args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user