mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.linalg.schur: error on 16-bit floats
Fixes https://github.com/google/jax/issues/10530 PiperOrigin-RevId: 446279906
This commit is contained in:
parent
37ea024d39
commit
c6343ddf8e
@ -1607,18 +1607,19 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
|
||||
operand_aval, = ctx.avals_in
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
|
||||
if sort_eig_vals:
|
||||
T, vs, _sdim, info = lapack.gees_mhlo(
|
||||
operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
else:
|
||||
T, vs, info = lapack.gees_mhlo(
|
||||
operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
# TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
|
||||
try:
|
||||
gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
except TypeError: # API for jaxlib <= 0.3.7
|
||||
gees_result = lapack.gees_mhlo(operand, # pytype: disable=missing-parameter
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
# Number of return values depends on value of sort_eig_vals.
|
||||
T, vs, *_, info = gees_result
|
||||
|
||||
ok = mlir.compare_mhlo(
|
||||
info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
|
@ -672,7 +672,7 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
|
||||
|
||||
# # gees : Schur factorization
|
||||
|
||||
def gees_mhlo(a, jobvs=True, sort=False, select=None):
|
||||
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
etype = a_type.element_type
|
||||
dims = a_type.shape
|
||||
@ -695,10 +695,19 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
|
||||
jobvs = ord('V' if jobvs else 'N')
|
||||
sort = ord('S' if sort else 'N')
|
||||
|
||||
if not ir.ComplexType.isinstance(etype):
|
||||
fn = "lapack_sgees" if etype == ir.F32Type.get() else "lapack_dgees"
|
||||
schurvecs_type = etype
|
||||
workspaces = [ir.RankedTensorType.get(dims, schurvecs_type)]
|
||||
if dtype == np.float32:
|
||||
fn = "lapack_sgees"
|
||||
elif dtype == np.float64:
|
||||
fn = "lapack_dgees"
|
||||
elif dtype == np.complex64:
|
||||
fn = "lapack_cgees"
|
||||
elif dtype == np.complex128:
|
||||
fn = "lapack_zgees"
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {dtype}")
|
||||
|
||||
if not np.issubdtype(dtype, np.complexfloating):
|
||||
workspaces = [ir.RankedTensorType.get(dims, etype)]
|
||||
workspace_layouts = [layout]
|
||||
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
|
||||
eigvals_layouts = [
|
||||
@ -706,11 +715,8 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
|
||||
type=ir.IndexType.get())
|
||||
] * 2
|
||||
else:
|
||||
fn = ("lapack_cgees" if etype == ir.ComplexType.get(ir.F32Type.get())
|
||||
else "lapack_zgees")
|
||||
schurvecs_type = etype
|
||||
workspaces = [
|
||||
ir.RankedTensorType.get(dims, schurvecs_type),
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
|
||||
]
|
||||
workspace_layouts = [
|
||||
@ -729,7 +735,7 @@ def gees_mhlo(a, jobvs=True, sort=False, select=None):
|
||||
type=ir.IndexType.get())
|
||||
out = mhlo.CustomCallOp(
|
||||
[ir.TupleType.get_tuple(workspaces + eigvals + [
|
||||
ir.RankedTensorType.get(dims, schurvecs_type),
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
])],
|
||||
|
@ -41,6 +41,8 @@ T = lambda x: np.swapaxes(x, -1, -2)
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
|
||||
jaxlib_version = tuple(map(int, jax.lib.__version__.split('.')))
|
||||
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@ -719,6 +721,19 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
qr = partial(jnp.linalg.qr, mode=mode)
|
||||
jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)
|
||||
|
||||
@unittest.skipIf(jaxlib_version < (0, 3, 8), "test requires jaxlib>=0.3.8")
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
|
||||
# Regression test for https://github.com/google/jax/issues/10530
|
||||
rng = jtu.rand_default(self.rng())
|
||||
arr = rng(shape, dtype)
|
||||
if jtu.device_under_test() == 'cpu':
|
||||
err, msg = NotImplementedError, "Unsupported dtype float16"
|
||||
else:
|
||||
err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)"
|
||||
with self.assertRaisesRegex(err, msg):
|
||||
jnp.linalg.qr(arr)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user