[shape_poly] linalg.schur: shape polymorphism with native lowering on CPU

PiperOrigin-RevId: 543533821
Author: George Necula (committed by jax authors)
Date:   2023-06-26 13:58:23 -07:00
Parent: a91412e1e7
Commit: c6a60054b9
5 changed files with 65 additions and 27 deletions
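For illustration, a minimal sketch of what this change enables (not part of the diff; it assumes a CPU backend, TensorFlow installed, and a jaxlib new enough to carry the dynamic-shape kernels; the dimension names "b" and "w" are arbitrary): lax.linalg.schur can now be serialized natively with symbolic dimensions via jax2tf.

# Minimal sketch, not part of the commit. Assumes CPU backend and TensorFlow available.
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def schur_fn(a):
  # Schur decomposition; on CPU this lowers to the lapack_*gees custom calls.
  return lax.linalg.schur(a, compute_schur_vectors=True)

# Keep the batch size "b" and the matrix size "w" symbolic in the serialized function.
schur_tf = jax2tf.convert(schur_fn, polymorphic_shapes=["b, w, w"],
                          native_serialization=True)

a = np.random.RandomState(0).randn(2, 4, 4).astype(np.float32)
T, vs = schur_tf(a)          # works for any concrete b and w at call time
print(T.shape, vs.shape)     # (2, 4, 4) (2, 4, 4)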


@@ -2055,10 +2055,18 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
operand_aval, = ctx.avals_in
batch_dims = operand_aval.shape[:-2]
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable)
if jaxlib_version < (0, 4, 14):
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable) # type: ignore
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable,
a_shape_vals=a_shape_vals)
# Number of return values depends on value of sort_eig_vals.
T, vs, *_, info = gees_result


@@ -732,6 +732,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
# # lu on CPU
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
# schur on CPU
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
# # lu on GPU
# "cublas_getrf_batched", "cusolver_getrf",
# "hipblas_getrf_batched", "hipsolver_getrf",


@@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
"lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
# TODO(necula): add tests for triangular_solve on CPU
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
# TODO(necula): add tests for schur on CPU
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered)


@@ -2579,6 +2579,24 @@ _POLY_SHAPE_TEST_HARNESSES = [
RandArg((7, 1), _f32), # updates: [b, 1]
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))],
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
[
PolyHarness("schur",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}",
lambda a, compute_schur_vectors: lax.linalg.schur(
a, compute_schur_vectors=compute_schur_vectors),
arg_descriptors=[RandArg(shape, dtype),
StaticArg(compute_schur_vectors)],
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for compute_schur_vectors in [True, False]
for (shape, poly) in [
((3, 3), "w, w"),
((3, 4, 4), "b, w, w"),
]
],
PolyHarness("select", "0",
# x.shape = (b, 3)
lambda x: lax.select(x > 5., x, x),
@@ -2861,6 +2879,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
if harness.group_name == "schur" and jtu.device_under_test() != "cpu":
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
if "fft_fft_type" in harness.fullname:
if "nr_fft_lengths=2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
@@ -2916,6 +2937,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# For non-native serialization the overflow behavior is different.
harness.check_result = False
if harness.group_name == "schur":
raise unittest.SkipTest("jax2tf graph serialization does not support schur.")
if harness.group_name == "eig" and "left=True_right=True" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")


@@ -568,19 +568,15 @@ def geev_hlo(dtype, input, *,
# # gees : Schur factorization
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None,
a_shape_vals: tuple[DimensionSize, ...]):
_initialize()
a_type = ir.RankedTensorType(a.type)
etype = a_type.element_type
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
assert len(a_shape_vals) >= 2
n = a_shape_vals[-1]
batch_dims_vals = a_shape_vals[:-2]
num_bd = len(batch_dims_vals)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
if sort:
@@ -601,33 +597,38 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
workspaces: list[ShapeTypePair]
eigvals: list[ShapeTypePair]
if not np.issubdtype(dtype, np.complexfloating):
workspaces = [ir.RankedTensorType.get(dims, etype)]
workspaces = [(a_shape_vals, etype)]
workspace_layouts = [layout]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
eigvals = [(batch_dims_vals + (n,), etype)] * 2
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
else:
workspaces = [
ir.RankedTensorType.get(dims, etype),
ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
workspaces = [(a_shape_vals, etype),
([n], ir.ComplexType(etype).element_type),
]
workspace_layouts = [layout, [0]]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
eigvals = [(batch_dims_vals + (n,), etype)]
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
i32_type = ir.IntegerType.get_signless(32)
scalar_layout = []
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
shape_type_pairs = workspaces + eigvals + [
(a_shape_vals, etype),
(batch_dims_vals, i32_type),
(batch_dims_vals, i32_type)]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
out = custom_call(
fn,
workspaces + eigvals + [
ir.RankedTensorType.get(dims, etype),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
result_types,
[
hlo_s32(b),
hlo_s32(n),
batch_size_val,
ensure_hlo_s32(n),
hlo_u8(jobvs),
hlo_u8(sort),
# TODO: figure out how to put the callable select function here
@@ -640,6 +641,7 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
tuple(range(num_bd - 1, -1, -1)),
],
operand_output_aliases={4: 0},
result_shapes=result_shapes,
)
if sort == ord('S'):
return (out[0], out[3], out[4], out[5])
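Background on why the real branch above allocates two eigenvalue buffers while the complex branch allocates one (an illustrative sketch, not part of the commit, using scipy.linalg.schur rather than the LAPACK custom call itself): the real *gees routines report eigenvalues as separate real and imaginary parts and produce only a quasi upper-triangular T (2x2 blocks for complex-conjugate eigenvalue pairs), whereas the complex routines return a single complex eigenvalue array and a strictly upper-triangular T.

# Illustrative sketch, not part of the commit.
import numpy as np
from scipy.linalg import schur

a = np.random.RandomState(0).randn(4, 4)
T_real, Z_real = schur(a, output="real")      # real, quasi upper-triangular Schur form
T_cplx, Z_cplx = schur(a, output="complex")   # complex, upper-triangular Schur form

# Both factorizations reconstruct a (Z is orthogonal / unitary, respectively).
np.testing.assert_allclose(Z_real @ T_real @ Z_real.T, a, atol=1e-10)
np.testing.assert_allclose(Z_cplx @ T_cplx @ Z_cplx.conj().T, a, atol=1e-10)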