From cfef568dd63afb91a962b7f210d1ed2e513a90f4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 29 Jan 2020 11:24:40 -0500 Subject: [PATCH] Implement jax.scipy.linalg.block_diag. (#2113) --- docs/jax.scipy.rst | 1 + jax/scipy/linalg.py | 24 +++++++++++++++++++++++- tests/linalg_test.py | 17 +++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 9b3f07a59..336233bd7 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -9,6 +9,7 @@ jax.scipy.linalg .. autosummary:: :toctree: _autosummary + block_diag cho_factor cho_solve cholesky diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index b8e82357b..e285f2d4f 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -278,7 +278,7 @@ def _expm(A, upper_triangular=False): R = _solve_P_Q(P, Q, upper_triangular) R = _squaring(R, n_squarings) return R - + @jit def _calc_P_Q(A): A = np.asarray(A) @@ -379,3 +379,25 @@ def _pade13(A): U = np.dot(A,np.dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident) V = np.dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident return U,V + + +@_wraps(scipy.linalg.block_diag) +@jit +def block_diag(*arrs): + if len(arrs) == 0: + arrs = [np.zeros((1, 0))] + arrs = np._promote_dtypes(*arrs) + bad_shapes = [i for i, a in enumerate(arrs) if np.ndim(a) > 2] + if bad_shapes: + raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at " + "most 2 dimensions, got {} at argument {}." + .format(arrs[bad_shapes[0]], bad_shapes[0])) + arrs = [np.atleast_2d(a) for a in arrs] + acc = arrs[0] + dtype = lax.dtype(acc) + for a in arrs[1:]: + _, c = a.shape + a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0))) + acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0))) + acc = lax.concatenate([acc, a], dimension=0) + return acc diff --git a/tests/linalg_test.py b/tests/linalg_test.py index bb37d71e3..33012d963 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -682,6 +682,23 @@ class NumpyLinalgTest(jtu.JaxTestCase): class ScipyLinalgTest(jtu.JaxTestCase): + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_i={}".format(i), "args": args} + for i, args in enumerate([ + (), + (1,), + (7, -2), + (3, 4, 5), + (onp.ones((3, 4), dtype=np.float_), 5, + onp.random.randn(5, 2).astype(np.float_)), + ]))) + def testBlockDiag(self, args): + args_maker = lambda: args + self._CheckAgainstNumpy(osp.linalg.block_diag, jsp.linalg.block_diag, + args_maker, check_dtypes=True) + self._CompileAndCheck(jsp.linalg.block_diag, args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),