[shape_poly] linalg.eig: shape polymorphism with native serialization on CPU

The backwards compatibility tests will be added separately.

PiperOrigin-RevId: 541122069
This commit is contained in:
George Necula 2023-06-16 23:58:37 -07:00 committed by jax authors
parent 68a38c6021
commit 3adfe321b0
7 changed files with 201 additions and 56 deletions

View File

@ -553,7 +553,6 @@ def sharded_aval(aval: core.AbstractValue,
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
# assert not core.is_constant_shape(shape)
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
@ -565,6 +564,20 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
                               shape: core.Shape) -> Tuple[Value, ...]:
  """Evaluates the dynamic shapes as int32 values.

  Each dimension produced by `eval_dynamic_shape(ctx, shape)` is turned into
  an IR value:
    * a Python `int` becomes an int32 constant built with `ir_constant`;
    * any other dimension is returned unchanged when its type already equals
      the scalar int32 IR type, and is passed through `hlo.ConvertOp`
      otherwise.

  Args:
    ctx: the lowering rule context, forwarded to `eval_dynamic_shape`.
    shape: the shape whose dimensions are evaluated.

  Returns:
    A tuple with one IR value per dimension of `shape`.
  """
  def as_i32_value(dim: Union[int, Value]):
    # Exact type check (not isinstance), so only plain Python ints are
    # materialized as constants.
    if type(dim) is int:
      return ir_constant(np.array(dim, dtype=np.int32))
    scalar_i32 = aval_to_ir_type(core.ShapedArray((), np.int32))
    if dim.type != scalar_i32:  # type: ignore
      return hlo.ConvertOp(scalar_i32, dim).result
    return dim

  return tuple(map(as_i32_value, eval_dynamic_shape(ctx, shape)))
class LoweringResult(NamedTuple):
module: ir.Module

View File

@ -487,15 +487,23 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
if jaxlib_version < (0, 4, 13):
if any(not is_constant_shape(a.shape) for a in ctx.avals_in):
raise NotImplementedError(
"Shape polymorphism for eig is not implemented. "
"Try upgrading jaxlib")
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, # type: ignore
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
else:
op_shape_vals = mlir.eval_dynamic_shape_as_vals(ctx, operand_aval.shape)
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),

View File

@ -694,6 +694,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"cusolver_syevj", "cusolver_syevd",
# eigh on TPU
"Eigh",
# eig on CPU
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
# qr on CPU
"lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
"lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",

View File

@ -412,6 +412,9 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
self.assertIsInstance(data, CompatTestData)
covered_targets = covered_targets.union(data.custom_call_targets)
# TODO(necula): add tests for eig on CPU
covered_targets = covered_targets.union({
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev"})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered)

View File

@ -2021,7 +2021,21 @@ _POLY_SHAPE_TEST_HARNESSES = [
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, _", None]).both_enable_and_disable_xla(),
[
PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
arg_descriptors=[RandArg((3, 5, 5), dtype),
StaticArg(left), StaticArg(right)],
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for poly in ["b, ...", "b, w, w"]
for left in ([True, False] if dtype == np.float32 else [True])
for right in ([True, False] if dtype == np.float32 else [False])
],
PolyHarness("einsum", "0",
lambda x: jnp.einsum("...i->...", x),
arg_descriptors=[RandArg((3, 4), _f32)],
@ -2728,7 +2742,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"vmap_cholesky:cpu", "vmap_cholesky:gpu",
"vmap_eig:cpu",
"vmap_fft:cpu", "fft:cpu",
"householder_product:cpu", "householder_product:gpu",
"vmap_geqrf:cpu", "vmap_geqrf:gpu",
@ -2795,11 +2808,17 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# For non-native serialization the overflow behavior is different.
harness.check_result = False
if harness.group_name == "eig" and "left=True_right=True" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")
# FOR BOTH NATIVE AND GRAPH SERIALIZATION
if harness.group_name == "vmap_conv_general_dilated":
# https://github.com/openxla/stablehlo/issues/1268
raise unittest.SkipTest("Need more dynamism for DynamicConvOp")
if harness.group_name == "eig" and jtu.device_under_test() != "cpu":
raise unittest.SkipTest("JAX implements eig only on CPU.")
prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()

View File

@ -12,13 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for building MLIR operators
from typing import Dict, Optional, Sequence, Union
"""A small library of helpers for use in jaxlib to build MLIR operations."""
from functools import partial
from typing import Callable, Dict, Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
# Maps a NumPy dtype to a zero-argument factory that builds the matching MLIR
# element type. The entries are callables rather than ir.Type instances, so a
# type is only constructed when its factory is invoked (dtype_to_ir_type does
# the lookup and the call).
# NOTE(review): presumably the construction is deferred because building an
# ir.Type needs an active MLIR context -- confirm.
_dtype_to_ir_type_factory : Dict[np.dtype, Callable[[], ir.Type]] = {
# Booleans are represented as 1-bit signless integers.
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
# Signed NumPy integers map to signless MLIR integers of the same width.
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
# Unsigned NumPy integers map to unsigned MLIR integers of the same width.
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
# Floating point types.
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
# Complex types are built from the corresponding float element type.
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
def dtype_to_ir_type(dtype) -> ir.Type:
  """Returns the MLIR element type registered for `dtype`.

  `dtype` is normalized with `np.dtype` before the lookup, so anything that
  `np.dtype` accepts may be passed. A dtype with no entry in
  `_dtype_to_ir_type_factory` raises `KeyError`.
  """
  make_type = _dtype_to_ir_type_factory[np.dtype(dtype)]
  return make_type()
def ir_constant(x: np.ndarray) -> ir.Value:
  """Emits an `hlo.ConstantOp` holding the contents of `x`.

  Args:
    x: a NumPy array; its dtype selects the IR element type through
      `dtype_to_ir_type`.

  Returns:
    The result value of the constant op.
  """
  assert isinstance(x, np.ndarray)
  element_type = dtype_to_ir_type(x.dtype)
  dense_attr = ir.DenseElementsAttr.get(x, type=element_type)
  return hlo.ConstantOp(dense_attr).result
def ir_constant_u8(x: int):
  """Returns an IR constant for `x` stored as a uint8 NumPy array."""
  as_uint8 = np.array(x, dtype=np.uint8)
  return ir_constant(as_uint8)
def ir_constant_i32(x: int):
  """Returns an IR constant for `x` stored as an int32 NumPy array."""
  as_int32 = np.array(x, dtype=np.int32)
  return ir_constant(as_int32)
def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  """Returns the ranked tensor type with the given `shape` and `dtype`.

  The element type is obtained from `dtype_to_ir_type(dtype)`.
  """
  element_type = dtype_to_ir_type(dtype)
  return ir.RankedTensorType.get(shape, element_type)
# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value:
  """Packs a sequence of dimension sizes into a single int32 IR value.

  Every size is first turned into a value of shape (1,) and element type
  int32:
    * a Python `int` becomes a one-element int32 constant;
    * any other size is converted with `hlo.ConvertOp` if its type is not
      already the scalar int32 type, then reshaped to (1,).

  Args:
    sizes: the dimension sizes, each a Python int or an IR value.

  Returns:
    An empty int32 constant when `sizes` is empty, the single (1,)-shaped
    value when there is exactly one size, and otherwise the concatenation of
    all pieces along dimension 0.
  """
  vec1_i32 = shape_dtype_to_ir_type((1,), np.int32)
  scalar_i32 = shape_dtype_to_ir_type((), np.int32)

  def as_len1_i32(dim):
    # Exact type check: only plain Python ints are emitted as constants.
    if type(dim) is int:
      return ir_constant(np.array([dim], dtype=np.int32))
    if dim.type != scalar_i32:
      dim = hlo.ConvertOp(scalar_i32, dim).result
    return hlo.ReshapeOp(vec1_i32, dim).result

  pieces = list(map(as_len1_i32, sizes))
  if not pieces:
    return ir_constant(np.array([], np.int32))
  if len(pieces) == 1:
    # A single piece already has the final shape; no concatenation needed.
    return pieces[0]
  concat_dimension = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)
  return hlo.ConcatenateOp(pieces, concat_dimension).result
# TODO(necula): share this with mlir.custom_call
def custom_call(
call_target_name: Union[str, bytes],
out_types: Sequence[ir.Type],
@ -42,12 +97,12 @@ def custom_call(
match the number of the results. They are appended to the list
of operands.
"""
i32_type = ir.IntegerType.get_signless(32)
attributes = dict(
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=ir.StringAttr.get(backend_config),
api_version=ir.IntegerAttr.get(i32_type, api_version),
api_version=ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), api_version),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
hlo.OutputOperandAlias.get(

View File

@ -19,9 +19,13 @@ import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from typing import Tuple
from jaxlib import xla_client
from .hlo_helpers import custom_call
from .hlo_helpers import (
custom_call, ir_constant_u8, ir_constant_i32,
shape_tensor
)
from .cpu import _lapack
for _name, _value in _lapack.registrations().items():
@ -477,85 +481,126 @@ def syevd_hlo(dtype, a: ir.Value, batch_size: ir.Value,
return out[:3]
# # geev: Nonsymmetric eigendecomposition
# # geev: Nonsymmetric eigendecomposition (eig)
def geev_hlo(dtype, a, jobvl=True, jobvr=True):
def geev_hlo(dtype, input, *,
input_shape_vals: Tuple[ir.Value, ...], # input.shape as ir.Values
jobvl=True, jobvr=True):
# input_shape_vals is used when input has dynamic shapes.
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
input_shape = ir.RankedTensorType(input.type).shape
assert len(input_shape) >= 2
n = input_shape[-1]
n_val: ir.Value = input_shape_vals[-1]
batch_dims = tuple(input_shape[:-2])
batch_dims_vals = input_shape_vals[:-2]
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
jobvl_c = ord('V' if jobvl else 'N')
jobvr_c = ord('V' if jobvr else 'N')
i32_type = ir.IntegerType.get_signless(32)
f32_type = ir.F32Type.get()
f64_type = ir.F64Type.get()
c64_type = ir.ComplexType.get(ir.F32Type.get())
c128_type = ir.ComplexType.get(ir.F64Type.get())
if n == ir.ShapedType.get_dynamic_size():
two_n = ir.ShapedType.get_dynamic_size()
else:
two_n = n + n
if dtype == np.float32:
fn = b"lapack_sgeev"
real = True
eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
workspaces = [ir.RankedTensorType.get([n, n], ir.F32Type.get()),
ir.RankedTensorType.get([n, n], ir.F32Type.get()),
ir.RankedTensorType.get([n, n], ir.F32Type.get())]
eigvecs_type = c64_type
workspace_types = [ir.RankedTensorType.get([n, n], f32_type)] * 3
workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
workspace_layouts = [[0, 1]] * 3
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get()),
ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get())]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), f32_type)] * 2
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))] * 2
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
elif dtype == np.float64:
fn = b"lapack_dgeev"
real = True
eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
workspaces = [ir.RankedTensorType.get([n, n], ir.F64Type.get()),
ir.RankedTensorType.get([n, n], ir.F64Type.get()),
ir.RankedTensorType.get([n, n], ir.F64Type.get())]
eigvecs_type = c128_type
workspace_types = [ir.RankedTensorType.get([n, n], f64_type)] * 3
workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
workspace_layouts = [[0, 1]] * 3
eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get()),
ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get())]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), f64_type)] * 2
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))] * 2
eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
elif dtype == np.complex64:
fn = b"lapack_cgeev"
real = False
eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
workspaces = [ir.RankedTensorType.get([n, n],
ir.ComplexType.get(ir.F32Type.get())),
ir.RankedTensorType.get([2 * n], ir.F32Type.get())]
eigvecs_type = c64_type
workspace_types = [
ir.RankedTensorType.get([n, n], c64_type),
ir.RankedTensorType.get([two_n], f32_type)]
workspace_result_shapes = [
shape_tensor((n_val, n_val)),
shape_tensor((hlo.AddOp(n_val, n_val).result,))]
workspace_layouts = [[0, 1], [0]]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
ir.ComplexType.get(ir.F32Type.get()))]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), c64_type)]
eigval_result_shapes = [shape_tensor(batch_dims_vals + (n_val,))]
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
elif dtype == np.complex128:
fn = b"lapack_zgeev"
real = False
eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
workspaces = [ir.RankedTensorType.get([n, n],
ir.ComplexType.get(ir.F64Type.get())),
ir.RankedTensorType.get([2 * n], ir.F64Type.get())]
eigvecs_type = c128_type
workspace_types = [
ir.RankedTensorType.get([n, n], c128_type),
ir.RankedTensorType.get([two_n], f64_type)]
workspace_result_shapes = [
shape_tensor((n_val, n_val)),
shape_tensor((hlo.AddOp(n_val, n_val).result,))]
workspace_layouts = [[0, 1], [0]]
eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
ir.ComplexType.get(ir.F64Type.get()))]
eigval_types = [
ir.RankedTensorType.get(batch_dims + (n,), c128_type)]
eigval_result_shapes = [
shape_tensor(batch_dims_vals + (n_val,))]
eigvals_layouts = [tuple(range(num_bd, -1, -1))]
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
i32_type = ir.IntegerType.get_signless(32)
scalar_layout = []
info_layout = tuple(range(num_bd - 1, -1, -1))
batch_size_val = ir_constant_i32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, b_v).result
result_types = (
workspace_types + eigval_types + [
ir.RankedTensorType.get(input_shape, eigvecs_type),
ir.RankedTensorType.get(input_shape, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
])
if any(a == ir.ShapedType.get_dynamic_size() for a in input_shape):
result_shapes = workspace_result_shapes + eigval_result_shapes + [
shape_tensor(input_shape_vals),
shape_tensor(input_shape_vals),
shape_tensor(batch_dims_vals),
]
else:
result_shapes = None
out = custom_call(
fn,
workspaces + eigvals + [
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a],
result_types,
[batch_size_val, n_val,
ir_constant_u8(jobvl_c),
ir_constant_u8(jobvr_c),
input],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 +
[info_layout])
[info_layout]),
result_shapes=result_shapes,
)
if real:
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])