jax.scipy.linalg.schur: error on 16-bit floats

Fixes https://github.com/google/jax/issues/10530

PiperOrigin-RevId: 446279906
Jake VanderPlas, 2022-05-03 13:47:11 -07:00 (committed by jax authors)
commit c6343ddf8e · parent 37ea024d39
3 changed files with 44 additions and 22 deletions
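
Before the per-file diffs, a minimal sketch of the user-visible behavior this change establishes (assuming a CPU backend and jaxlib >= 0.3.8; the exception type and message mirror the new regression test below, so treat this as illustrative rather than authoritative):

# Illustrative repro: LAPACK's *gees kernels exist only for
# float32/float64/complex64/complex128, so a 16-bit Schur
# decomposition should now fail loudly with a clear message.
import numpy as np
import jax.scipy.linalg

x = np.eye(4, dtype=np.float16)
try:
  jax.scipy.linalg.schur(x)
except NotImplementedError as e:
  print(e)  # expected on CPU: "Unsupported dtype float16"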

diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py

@@ -1607,18 +1607,19 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   operand_aval, = ctx.avals_in
   batch_dims = operand_aval.shape[:-2]
-  if sort_eig_vals:
-    T, vs, _sdim, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
-  else:
-    T, vs, info = lapack.gees_mhlo(
-        operand,
-        jobvs=compute_schur_vectors,
-        sort=sort_eig_vals,
-        select=select_callable)
+  # TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
+  try:
+    gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  except TypeError:  # API for jaxlib <= 0.3.7
+    gees_result = lapack.gees_mhlo(operand,  # pytype: disable=missing-parameter
+                                   jobvs=compute_schur_vectors,
+                                   sort=sort_eig_vals,
+                                   select=select_callable)
+  # Number of return values depends on value of sort_eig_vals.
+  T, vs, *_, info = gees_result
   ok = mlir.compare_mhlo(
       info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
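
Two mechanics in this hunk deserve a note: the try/except bridges the gees_mhlo signature change (jaxlib >= 0.3.8 takes a leading dtype argument; older jaxlib raises TypeError when given one), and the starred unpacking absorbs the extra _sdim value present only when sort_eig_vals is set. A self-contained sketch of that unpacking idiom, using placeholder strings rather than real gees_mhlo outputs:

# Placeholder tuples standing in for the two possible return shapes.
result_sorted = ("T", "vs", "sdim", "info")  # sort_eig_vals=True
result_unsorted = ("T", "vs", "info")        # sort_eig_vals=False
for gees_result in (result_sorted, result_unsorted):
  T, vs, *_, info = gees_result  # *_ absorbs "sdim" when present
  assert (T, vs, info) == ("T", "vs", "info")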

diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py

@@ -672,7 +672,7 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
 
 # # gees : Schur factorization
 
-def gees_mhlo(a, jobvs=True, sort=False, select=None):
+def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
   a_type = ir.RankedTensorType(a.type)
   etype = a_type.element_type
   dims = a_type.shape
@@ -695,10 +695,19 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
   jobvs = ord('V' if jobvs else 'N')
   sort = ord('S' if sort else 'N')
 
-  if not ir.ComplexType.isinstance(etype):
-    fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
-    schurvecs_type = etype
-    workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
+  if dtype == np.float32:
+    fn = "lapack_sgees"
+  elif dtype == np.float64:
+    fn = "lapack_dgees"
+  elif dtype == np.complex64:
+    fn = "lapack_cgees"
+  elif dtype == np.complex128:
+    fn = "lapack_zgees"
+  else:
+    raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+  if not np.issubdtype(dtype, np.complexfloating):
+    workspaces = [ir.RankedTensorType.get(dims, etype)]
     workspace_layouts = [layout]
     eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
     eigvals_layouts = [
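
This hunk is the heart of the fix: the old code inspected the MLIR element type, and any real type that was not F32 (including F16 and BF16) fell into the else branch and was routed to lapack_dgees, the float64 kernel. Dispatching on the NumPy dtype instead makes unsupported types fail fast. An equivalent table-driven sketch (gees_target_name and _GEES_TARGETS are illustrative names, not jaxlib API):

import numpy as np

_GEES_TARGETS = {
    np.dtype(np.float32): "lapack_sgees",
    np.dtype(np.float64): "lapack_dgees",
    np.dtype(np.complex64): "lapack_cgees",
    np.dtype(np.complex128): "lapack_zgees",
}

def gees_target_name(dtype):
  # Fail fast on anything LAPACK gees has no kernel for (e.g. float16).
  target = _GEES_TARGETS.get(np.dtype(dtype))
  if target is None:
    raise NotImplementedError(f"Unsupported dtype {np.dtype(dtype)}")
  return target

assert gees_target_name(np.complex64) == "lapack_cgees"
# gees_target_name(np.float16) now raises instead of picking lapack_dgees.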
@@ -706,11 +715,8 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
                                     type=ir.IndexType.get())
     ] * 2
   else:
-    fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
-          else "lapack_zgees")
-    schurvecs_type = etype
     workspaces = [
-        ir.RankedTensorType.get(dims, schurvecs_type),
+        ir.RankedTensorType.get(dims, etype),
         ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
     ]
     workspace_layouts = [
@@ -729,7 +735,7 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
                                     type=ir.IndexType.get())
   out = mhlo.CustomCallOp(
       [ir.TupleType.get_tuple(workspaces + eigvals + [
-          ir.RankedTensorType.get(dims, schurvecs_type),
+          ir.RankedTensorType.get(dims, etype),
           ir.RankedTensorType.get(batch_dims, i32_type),
           ir.RankedTensorType.get(batch_dims, i32_type),
       ])],

diff --git a/tests/linalg_test.py b/tests/linalg_test.py

@@ -41,6 +41,8 @@ T = lambda x: np.swapaxes(x, -1, -2)
 
 float_types = jtu.dtypes.floating
 complex_types = jtu.dtypes.complex
 
+jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
+
 
 class NumpyLinalgTest(jtu.JaxTestCase):
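
The jaxlib_version tuple exists so version gates compare numerically; comparing the raw version strings would mis-order components once one reaches two digits. In isolation:

# Tuples of ints compare element-wise, so 0.3.10 sorts after 0.3.8.
parse_version = lambda v: tuple(map(int, v.split('.')))
assert parse_version("0.3.10") > parse_version("0.3.8")  # numeric: correct
assert not ("0.3.10" > "0.3.8")  # plain string comparison gets this wrong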
@@ -719,6 +721,19 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     qr = partial(jnp.linalg.qr, mode=mode)
     jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)
 
+  @unittest.skipIf(jaxlib_version < (0, 3, 8), "test requires jaxlib>=0.3.8")
+  @jtu.skip_on_devices("tpu")
+  def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
+    # Regression test for https://github.com/google/jax/issues/10530
+    rng = jtu.rand_default(self.rng())
+    arr = rng(shape, dtype)
+    if jtu.device_under_test() == 'cpu':
+      err, msg = NotImplementedError, "Unsupported dtype float16"
+    else:
+      err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
+    with self.assertRaisesRegex(err, msg):
+      jnp.linalg.qr(arr)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
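
A footnote on the test body above: assertRaisesRegex treats the expected message as a regular expression, which is why the non-CPU branch escapes the parentheses in the expected text. Verifying the pattern in isolation:

import re

pattern = r"Unsupported dtype dtype\('float16'\)"  # parens escaped for re
assert re.search(pattern, "Unsupported dtype dtype('float16')")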