diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 64286d4f3..98e9a589d 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2055,10 +2055,18 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, operand_aval, = ctx.avals_in batch_dims = operand_aval.shape[:-2] - gees_result = lapack.gees_hlo(operand_aval.dtype, operand, - jobvs=compute_schur_vectors, - sort=sort_eig_vals, - select=select_callable) + if jaxlib_version < (0, 4, 14): + gees_result = lapack.gees_hlo(operand_aval.dtype, operand, + jobvs=compute_schur_vectors, + sort=sort_eig_vals, + select=select_callable) # type: ignore + else: + a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) + gees_result = lapack.gees_hlo(operand_aval.dtype, operand, + jobvs=compute_schur_vectors, + sort=sort_eig_vals, + select=select_callable, + a_shape_vals=a_shape_vals) # Number of return values depends on value of sort_eig_vals. T, vs, *_, info = gees_result diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 1d5b5f6b4..83054c8a9 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -732,6 +732,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU # # lu on CPU "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", + # schur on CPU + "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", # # lu on GPU # "cublas_getrf_batched", "cusolver_getrf", # "hipblas_getrf_batched", "hipsolver_getrf", diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 307a58555..2b5382af8 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase): "lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd", # TODO(necula): add tests for triangular_solve on CPU "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", + # TODO(necula): add tests for schur on CPU + "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 8d8bc1c57..e374ace40 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2579,6 +2579,24 @@ _POLY_SHAPE_TEST_HARNESSES = [ RandArg((7, 1), _f32), # updates: [b, 1] StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))], polymorphic_shapes=["b, ...", "b, ...", "b, ..."]), + [ + PolyHarness("schur", + f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}", + lambda a, compute_schur_vectors: lax.linalg.schur( + a, compute_schur_vectors=compute_schur_vectors), + arg_descriptors=[RandArg(shape, dtype), + StaticArg(compute_schur_vectors)], + polymorphic_shapes=[poly], + # In non-native serialization, we cannot check exact match, + # we ought to check the invariants of the result. + check_result=config.jax2tf_default_native_serialization) + for dtype in [np.float32, np.float64, np.complex64, np.complex128] + for compute_schur_vectors in [True, False] + for (shape, poly) in [ + ((3, 3), "w, w"), + ((3, 4, 4), "b, w, w"), + ] + ], PolyHarness("select", "0", # x.shape = (b, 3) lambda x: lax.select(x > 5., x, x), @@ -2861,6 +2879,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses: raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778") + if harness.group_name == "schur" and jtu.device_under_test() != "cpu": + raise unittest.SkipTest("schur decomposition is only implemented on CPU.") + if "fft_fft_type" in harness.fullname: if "nr_fft_lengths=2" in harness.fullname: raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU") @@ -2916,6 +2937,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): # For non-native serialization the overflow behavior is different. harness.check_result = False + if harness.group_name == "schur": + raise unittest.SkipTest("jax2tf graph serialization does not support schur.") + if harness.group_name == "eig" and "left=True_right=True" in harness.fullname: raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.") diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 63fbef735..d61aedb73 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -568,19 +568,15 @@ def geev_hlo(dtype, input, *, # # gees : Schur factorization -def gees_hlo(dtype, a, jobvs=True, sort=False, select=None): +def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, + a_shape_vals: tuple[DimensionSize, ...]): _initialize() a_type = ir.RankedTensorType(a.type) etype = a_type.element_type - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - assert m == n - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - b = 1 - for d in batch_dims: - b *= d + assert len(a_shape_vals) >= 2 + n = a_shape_vals[-1] + batch_dims_vals = a_shape_vals[:-2] + num_bd = len(batch_dims_vals) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) if sort: @@ -601,33 +597,38 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None): else: raise NotImplementedError(f"Unsupported dtype {dtype}") + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] if not np.issubdtype(dtype, np.complexfloating): - workspaces = [ir.RankedTensorType.get(dims, etype)] + workspaces = [(a_shape_vals, etype)] workspace_layouts = [layout] - eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] * 2 + eigvals = [(batch_dims_vals + (n,), etype)] * 2 eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 else: - workspaces = [ - ir.RankedTensorType.get(dims, etype), - ir.RankedTensorType.get([n], ir.ComplexType(etype).element_type), + workspaces = [(a_shape_vals, etype), + ([n], ir.ComplexType(etype).element_type), ] workspace_layouts = [layout, [0]] - eigvals = [ir.RankedTensorType.get(batch_dims + (n,), etype)] + eigvals = [(batch_dims_vals + (n,), etype)] eigvals_layouts = [tuple(range(num_bd, -1, -1))] i32_type = ir.IntegerType.get_signless(32) scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + shape_type_pairs = workspaces + eigvals + [ + (a_shape_vals, etype), + (batch_dims_vals, i32_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, - workspaces + eigvals + [ - ir.RankedTensorType.get(dims, etype), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ], + result_types, [ - hlo_s32(b), - hlo_s32(n), + batch_size_val, + ensure_hlo_s32(n), hlo_u8(jobvs), hlo_u8(sort), # TODO: figure out how to put the callable select function here @@ -640,6 +641,7 @@ def gees_hlo(dtype, a, jobvs=True, sort=False, select=None): tuple(range(num_bd - 1, -1, -1)), ], operand_output_aliases={4: 0}, + result_shapes=result_shapes, ) if sort == ord('S'): return (out[0], out[3], out[4], out[5])