mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[shape_poly] linalg.schur: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543533821
This commit is contained in:
parent
a91412e1e7
commit
c6a60054b9
@ -2055,10 +2055,18 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
|
||||
operand_aval, = ctx.avals_in
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
|
||||
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
if jaxlib_version < (0, 4, 14):
|
||||
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable) # type: ignore
|
||||
else:
|
||||
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
|
||||
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable,
|
||||
a_shape_vals=a_shape_vals)
|
||||
|
||||
# Number of return values depends on value of sort_eig_vals.
|
||||
T, vs, *_, info = gees_result
|
||||
|
@ -732,6 +732,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
# TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU
|
||||
# lu on CPU
|
||||
"lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf",
|
||||
# schur on CPU
|
||||
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
||||
# # lu on GPU
|
||||
# "cublas_getrf_batched", "cusolver_getrf",
|
||||
# "hipblas_getrf_batched", "hipsolver_getrf",
|
||||
|
@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
|
||||
"lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
|
||||
# TODO(necula): add tests for triangular_solve on CPU
|
||||
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
|
||||
# TODO(necula): add tests for schur on CPU
|
||||
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
||||
})
|
||||
not_covered = targets_to_cover.difference(covered_targets)
|
||||
self.assertEmpty(not_covered)
|
||||
|
@ -2579,6 +2579,24 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
RandArg((7, 1), _f32), # updates: [b, 1]
|
||||
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))],
|
||||
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
||||
[
|
||||
PolyHarness("schur",
|
||||
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}",
|
||||
lambda a, compute_schur_vectors: lax.linalg.schur(
|
||||
a, compute_schur_vectors=compute_schur_vectors),
|
||||
arg_descriptors=[RandArg(shape, dtype),
|
||||
StaticArg(compute_schur_vectors)],
|
||||
polymorphic_shapes=[poly],
|
||||
# In non-native serialization we cannot check for an exact match; instead
|
||||
# we ought to check the invariants of the result.
|
||||
check_result=config.jax2tf_default_native_serialization)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for compute_schur_vectors in [True, False]
|
||||
for (shape, poly) in [
|
||||
((3, 3), "w, w"),
|
||||
((3, 4, 4), "b, w, w"),
|
||||
]
|
||||
],
|
||||
PolyHarness("select", "0",
|
||||
# x.shape = (b, 3)
|
||||
lambda x: lax.select(x > 5., x, x),
|
||||
@ -2861,6 +2879,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
|
||||
|
||||
if harness.group_name == "schur" and jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
|
||||
|
||||
if "fft_fft_type" in harness.fullname:
|
||||
if "nr_fft_lengths=2" in harness.fullname:
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
|
||||
@ -2916,6 +2937,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# For non-native serialization the overflow behavior is different.
|
||||
harness.check_result = False
|
||||
|
||||
if harness.group_name == "schur":
|
||||
raise unittest.SkipTest("jax2tf graph serialization does not support schur.")
|
||||
|
||||
if harness.group_name == "eig" and "left=True_right=True" in harness.fullname:
|
||||
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")
|
||||
|
||||
|
@ -568,19 +568,15 @@ def geev_hlo(dtype, input, *,
|
||||
|
||||
# # gees : Schur factorization
|
||||
|
||||
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None,
|
||||
a_shape_vals: tuple[DimensionSize, ...]):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
etype = a_type.element_type
|
||||
dims = a_type.shape
|
||||
assert len(dims) >= 2
|
||||
m, n = dims[-2:]
|
||||
assert m == n
|
||||
batch_dims = tuple(dims[:-2])
|
||||
num_bd = len(batch_dims)
|
||||
b = 1
|
||||
for d in batch_dims:
|
||||
b *= d
|
||||
assert len(a_shape_vals) >= 2
|
||||
n = a_shape_vals[-1]
|
||||
batch_dims_vals = a_shape_vals[:-2]
|
||||
num_bd = len(batch_dims_vals)
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
|
||||
if sort:
|
||||
@ -601,33 +597,38 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {dtype}")
|
||||
|
||||
workspaces: list[ShapeTypePair]
|
||||
eigvals: list[ShapeTypePair]
|
||||
if not np.issubdtype(dtype, np.complexfloating):
|
||||
workspaces = [ir.RankedTensorType.get(dims, etype)]
|
||||
workspaces = [(a_shape_vals, etype)]
|
||||
workspace_layouts = [layout]
|
||||
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2
|
||||
eigvals = [(batch_dims_vals + (n,), etype)] * 2
|
||||
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
|
||||
else:
|
||||
workspaces = [
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type),
|
||||
workspaces = [(a_shape_vals, etype),
|
||||
([n], ir.ComplexType(etype).element_type),
|
||||
]
|
||||
workspace_layouts = [layout, [0]]
|
||||
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)]
|
||||
eigvals = [(batch_dims_vals + (n,), etype)]
|
||||
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
|
||||
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
|
||||
scalar_layout = []
|
||||
batch_size_val = hlo_s32(1)
|
||||
for b_v in batch_dims_vals:
|
||||
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
|
||||
shape_type_pairs = workspaces + eigvals + [
|
||||
(a_shape_vals, etype),
|
||||
(batch_dims_vals, i32_type),
|
||||
(batch_dims_vals, i32_type)]
|
||||
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
|
||||
out = custom_call(
|
||||
fn,
|
||||
workspaces + eigvals + [
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
],
|
||||
result_types,
|
||||
[
|
||||
hlo_s32(b),
|
||||
hlo_s32(n),
|
||||
batch_size_val,
|
||||
ensure_hlo_s32(n),
|
||||
hlo_u8(jobvs),
|
||||
hlo_u8(sort),
|
||||
# TODO: figure out how to put the callable select function here
|
||||
@ -640,6 +641,7 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
],
|
||||
operand_output_aliases={4: 0},
|
||||
result_shapes=result_shapes,
|
||||
)
|
||||
if sort == ord('S'):
|
||||
return (out[0], out[3], out[4], out[5])
|
||||
|
Loading…
x
Reference in New Issue
Block a user