diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index a63837ac9..1597c8c85 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1607,18 +1607,19 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   operand_aval, = ctx.avals_in
   batch_dims = operand_aval.shape[:-2]
 
-  if sort_eig_vals:
-    T, vs, _sdim, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
-  else:
-    T, vs, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
+  # TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
+  try:
+    gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  except TypeError:  # API for jaxlib <= 0.3.7
+    gees_result = lapack.gees_mhlo(operand,  # pytype: disable=missing-parameter
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  # Number of return values depends on value of sort_eig_vals.
+  T, vs, *_, info = gees_result
 
   ok = mlir.compare_mhlo(
       info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py
index 083385819..33caf5b21 100644
--- a/jaxlib/lapack.py
+++ b/jaxlib/lapack.py
@@ -672,7 +672,7 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
 
 # # gees : Schur factorization
 
-def gees_mhlo(a, jobvs=True, sort=False, select=None):
+def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
   a_type = ir.RankedTensorType(a.type)
   etype = a_type.element_type
   dims = a_type.shape
@@ -695,10 +695,19 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
   jobvs = ord('V' if jobvs else 'N')
   sort = ord('S' if sort else 'N')
 
-  if not ir.ComplexType.isinstance(etype):
-    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
-    schurvecs_type = etype
-    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
+  if dtype == np.float32:
+    fn = "lapack_sgees"
+  elif dtype == np.float64:
+    fn = "lapack_dgees"
+  elif dtype == np.complex64:
+    fn = "lapack_cgees"
+  elif dtype == np.complex128:
+    fn = "lapack_zgees"
+  else:
+    raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+  if not np.issubdtype(dtype, np.complexfloating):
+    workspaces = [ir.RankedTensorType.get(dims, etype)]
     workspace_layouts = [layout]
     eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
     eigvals_layouts = [
@@ -706,11 +715,8 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
                                          type=ir.IndexType.get())
     ] * 2
   else:
-    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
-          else "lapack_zgees")
-    schurvecs_type = etype
     workspaces = [
-        ir.RankedTensorType.get(dims, schurvecs_type),
+        ir.RankedTensorType.get(dims, etype),
         ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
     ]
     workspace_layouts = [
@@ -729,7 +735,7 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
                                        type=ir.IndexType.get())
   out = mhlo.CustomCallOp(
       [ir.TupleType.get_tuple(workspaces + eigvals + [
-          ir.RankedTensorType.get(dims, schurvecs_type),
+          ir.RankedTensorType.get(dims, etype),
           ir.RankedTensorType.get(batch_dims, i32_type),
           ir.RankedTensorType.get(batch_dims, i32_type),
       ])],
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 575954774..f676092f1 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -41,6 +41,8 @@ T = lambda x: np.swapaxes(x, -1, -2)
 float_types = jtu.dtypes.floating
 complex_types = jtu.dtypes.complex
 
+jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
+
 
 class NumpyLinalgTest(jtu.JaxTestCase):
 
@@ -719,6 +721,19 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     qr = partial(jnp.linalg.qr, mode=mode)
     jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)
 
+  @unittest.skipIf(jaxlib_version < (0, 3, 8), "test requires jaxlib>=0.3.8")
+  @jtu.skip_on_devices("tpu")
+  def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
+    # Regression test for https://github.com/google/jax/issues/10530
+    rng = jtu.rand_default(self.rng())
+    arr = rng(shape, dtype)
+    if jtu.device_under_test() == 'cpu':
+      err, msg = NotImplementedError, "Unsupported dtype float16"
+    else:
+      err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
+    with self.assertRaisesRegex(err, msg):
+      jnp.linalg.qr(arr)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
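
For reference, a minimal illustrative sketch (not part of the patch) of the user-facing behavior the new testQrInvalidDtypeCPU regression test asserts, assuming a jax/jaxlib install that includes the dtype dispatch added above:

# Illustrative sketch only -- not part of the diff. Assumes jax/jaxlib with
# the "Unsupported dtype" checks in place.
import numpy as np
import jax.numpy as jnp

x = np.ones((5, 6), dtype=np.float16)
try:
  jnp.linalg.qr(x)  # float16 has no LAPACK kernel, so lowering rejects it
except (NotImplementedError, ValueError) as e:
  # On CPU the check raises NotImplementedError("Unsupported dtype float16");
  # other backends raise ValueError, matching the test's assertions.
  print(e)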