Implement jax.scipy.linalg.block_diag. (#2113)

This commit is contained in:
Peter Hawkins 2020-01-29 11:24:40 -05:00 committed by GitHub
parent 0904e5ff74
commit cfef568dd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 1 deletions

View File

@ -9,6 +9,7 @@ jax.scipy.linalg
.. autosummary::
:toctree: _autosummary
block_diag
cho_factor
cho_solve
cholesky

View File

@ -278,7 +278,7 @@ def _expm(A, upper_triangular=False):
R = _solve_P_Q(P, Q, upper_triangular)
R = _squaring(R, n_squarings)
return R
@jit
def _calc_P_Q(A):
A = np.asarray(A)
@ -379,3 +379,25 @@ def _pade13(A):
U = np.dot(A,np.dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = np.dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V
@_wraps(scipy.linalg.block_diag)
@jit
def block_diag(*arrs):
if len(arrs) == 0:
arrs = [np.zeros((1, 0))]
arrs = np._promote_dtypes(*arrs)
bad_shapes = [i for i, a in enumerate(arrs) if np.ndim(a) > 2]
if bad_shapes:
raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
"most 2 dimensions, got {} at argument {}."
.format(arrs[bad_shapes[0]], bad_shapes[0]))
arrs = [np.atleast_2d(a) for a in arrs]
acc = arrs[0]
dtype = lax.dtype(acc)
for a in arrs[1:]:
_, c = a.shape
a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
acc = lax.concatenate([acc, a], dimension=0)
return acc

View File

@ -682,6 +682,23 @@ class NumpyLinalgTest(jtu.JaxTestCase):
class ScipyLinalgTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_i={}".format(i), "args": args}
for i, args in enumerate([
(),
(1,),
(7, -2),
(3, 4, 5),
(onp.ones((3, 4), dtype=np.float_), 5,
onp.random.randn(5, 2).astype(np.float_)),
])))
def testBlockDiag(self, args):
args_maker = lambda: args
self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag,
args_maker, check_dtypes=True)
self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),