mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)
[shape_poly] linalg.eig: shape polymorphism with native serialization on CPU
The backwards compatibility tests will be added separately.

PiperOrigin-RevId: 541122069
parent 68a38c6021
commit 3adfe321b0
@@ -553,7 +553,6 @@ def sharded_aval(aval: core.AbstractValue,

 def eval_dynamic_shape(ctx: LoweringRuleContext,
                        shape: core.Shape) -> Tuple[Union[int, Value], ...]:
-  # assert not core.is_constant_shape(shape)
   if config.jax_dynamic_shapes:
     return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore
   else:
@@ -565,6 +564,20 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
         multiple_results=True)(ctx, *ctx.dim_var_values)
     return util.flatten(res)  # type: ignore

+def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
+                               shape: core.Shape) -> Tuple[Value, ...]:
+  """Evaluates the dynamic shapes as int32 values."""
+  def convert_dim(d: Union[int, Value]):
+    if type(d) is int:
+      return ir_constant(np.array(d, dtype=np.int32))
+    else:
+      i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
+      if d.type != i32_type:  # type: ignore
+        return hlo.ConvertOp(i32_type, d).result
+      else:
+        return d
+  return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
+

 class LoweringResult(NamedTuple):
   module: ir.Module
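Note: a minimal sketch of how a lowering rule is expected to use the new helper. The rule name and body below are illustrative assumptions, not code from this commit.

def my_lowering_rule(ctx: LoweringRuleContext, operand):
  operand_aval, = ctx.avals_in
  # eval_dynamic_shape returns a mixed tuple: static dims stay Python
  # ints, symbolic dims become ir.Value.
  mixed = eval_dynamic_shape(ctx, operand_aval.shape)
  # eval_dynamic_shape_as_vals converts every entry to an i32 ir.Value,
  # ready to feed shape computations such as custom-call result_shapes.
  shape_vals = eval_dynamic_shape_as_vals(ctx, operand_aval.shape)
  return mixed, shape_vals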
@@ -487,15 +487,23 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,

 def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
                       compute_right_eigenvectors):
-  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
-    raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
   operand_aval, = ctx.avals_in
   out_aval = ctx.avals_out[0]
   batch_dims = operand_aval.shape[:-2]

-  w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
-                                    jobvl=compute_left_eigenvectors,
-                                    jobvr=compute_right_eigenvectors)
+  if jaxlib_version < (0, 4, 13):
+    if any(not is_constant_shape(a.shape) for a in ctx.avals_in):
+      raise NotImplementedError(
+          "Shape polymorphism for eig is not implemented. "
+          "Try upgrading jaxlib")
+    w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,  # type: ignore
+                                      jobvl=compute_left_eigenvectors,
+                                      jobvr=compute_right_eigenvectors)
+  else:
+    op_shape_vals = mlir.eval_dynamic_shape_as_vals(ctx, operand_aval.shape)
+    w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
+                                      input_shape_vals=op_shape_vals,
+                                      jobvl=compute_left_eigenvectors,
+                                      jobvr=compute_right_eigenvectors)

   ok = mlir.compare_hlo(
       info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
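Note: a hedged sketch of the user-visible capability this hunk enables, assuming a CPU backend and jaxlib >= 0.4.13; the symbolic dimension "b" is the polymorphic batch size.

import jax.numpy as jnp
from jax.experimental import jax2tf

# Natively serialize eig with a polymorphic batch dimension; the 5x5
# trailing dimensions stay static here for simplicity.
eig_tf = jax2tf.convert(jnp.linalg.eig,
                        polymorphic_shapes=["b, 5, 5"],
                        native_serialization=True)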
@@ -694,6 +694,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
     "cusolver_syevj", "cusolver_syevd",
     # eigh on TPU
     "Eigh",
+    # eig on CPU
+    "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
     # qr on CPU
     "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf",
     "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr",
@@ -412,6 +412,9 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
       self.assertIsInstance(data, CompatTestData)
       covered_targets = covered_targets.union(data.custom_call_targets)

+    # TODO(necula): add tests for eig on CPU
+    covered_targets = covered_targets.union({
+        "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev"})
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered)
@@ -2021,7 +2021,21 @@ _POLY_SHAPE_TEST_HARNESSES = [
               # x:shape: (b, 4)
               lambda x, idx: lax.dynamic_update_slice(x, x, idx),
               arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
-              polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+              polymorphic_shapes=["b, _", None]).both_enable_and_disable_xla(),
+    [
+        PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
+                    lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
+                    arg_descriptors=[RandArg((3, 5, 5), dtype),
+                                     StaticArg(left), StaticArg(right)],
+                    polymorphic_shapes=[poly],
+                    # In non-native serialization we cannot check an exact match;
+                    # we ought to check the invariants of the result.
+                    check_result=config.jax2tf_default_native_serialization)
+        for dtype in [np.float32, np.float64, np.complex64, np.complex128]
+        for poly in ["b, ...", "b, w, w"]
+        for left in ([True, False] if dtype == np.float32 else [True])
+        for right in ([True, False] if dtype == np.float32 else [False])
+    ],
     PolyHarness("einsum", "0",
                 lambda x: jnp.einsum("...i->...", x),
                 arg_descriptors=[RandArg((3, 4), _f32)],
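Note: one concrete instantiation of the comprehension above behaves like the plain call below (a sketch, not from the diff; JAX implements eig only on CPU).

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.asarray(np.random.randn(3, 5, 5).astype(np.float32))
# With both flags set, lax.linalg.eig returns the eigenvalues plus the
# left and right eigenvectors.
w, vl, vr = lax.linalg.eig(x, compute_left_eigenvectors=True,
                           compute_right_eigenvectors=True)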
@@ -2728,7 +2742,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
     # Set of harness.group_name:platform that are implemented with custom call
     custom_call_harnesses = {
         "vmap_cholesky:cpu", "vmap_cholesky:gpu",
-        "vmap_eig:cpu",
         "vmap_fft:cpu", "fft:cpu",
         "householder_product:cpu", "householder_product:gpu",
         "vmap_geqrf:cpu", "vmap_geqrf:gpu",
@@ -2795,11 +2808,17 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       # For non-native serialization the overflow behavior is different.
       harness.check_result = False

+    if harness.group_name == "eig" and "left=True_right=True" in harness.fullname:
+      raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")
+
     # FOR BOTH NATIVE AND GRAPH SERIALIZATION
     if harness.group_name == "vmap_conv_general_dilated":
       # https://github.com/openxla/stablehlo/issues/1268
       raise unittest.SkipTest("Need more dynamism for DynamicConvOp")

+    if harness.group_name == "eig" and jtu.device_under_test() != "cpu":
+      raise unittest.SkipTest("JAX implements eig only on CPU.")
+
     prev_jax_config_flags = {
         fname: getattr(jax.config, fname)
         for fname, fvalue in harness.override_jax_config_flags.items()
@@ -12,13 +12,68 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Helpers for building MLIR operators
-from typing import Dict, Optional, Sequence, Union
+"""A small library of helpers for use in jaxlib to build MLIR operations."""
+from functools import partial
+from typing import Callable, Dict, Optional, Sequence, Union

 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.stablehlo as hlo
 import numpy as np


+_dtype_to_ir_type_factory: Dict[np.dtype, Callable[[], ir.Type]] = {
+    np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
+    np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
+    np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
+    np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
+    np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
+    np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
+    np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
+    np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
+    np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
+    np.dtype(np.float16): ir.F16Type.get,
+    np.dtype(np.float32): ir.F32Type.get,
+    np.dtype(np.float64): ir.F64Type.get,
+    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
+    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
+}
+def dtype_to_ir_type(dtype) -> ir.Type:
+  return _dtype_to_ir_type_factory[np.dtype(dtype)]()
+
+def ir_constant(x: np.ndarray) -> ir.Value:
+  assert isinstance(x, np.ndarray)
+  return hlo.ConstantOp(
+      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))).result
+
+def ir_constant_u8(x: int): return ir_constant(np.array(x, dtype=np.uint8))
+def ir_constant_i32(x: int): return ir_constant(np.array(x, dtype=np.int32))
+
+def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
+  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))
+
+
+# TODO(necula): share this with mlir.shape_tensor
+def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value:
+  int1d = shape_dtype_to_ir_type((1,), np.int32)
+  i32_type = shape_dtype_to_ir_type((), np.int32)
+  def dim_to_i32x1(d):
+    if type(d) is int:
+      return ir_constant(np.array([d], dtype=np.int32))
+    else:
+      if d.type != i32_type:
+        d = hlo.ConvertOp(i32_type, d).result
+      return hlo.ReshapeOp(int1d, d).result
+  ds = [dim_to_i32x1(sz) for sz in sizes]
+  if not ds:
+    return ir_constant(np.array([], np.int32))
+  elif len(ds) == 1:
+    return ds[0]
+  else:
+    return hlo.ConcatenateOp(
+        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)).result
+
+
+# TODO(necula): share this with mlir.custom_call
 def custom_call(
     call_target_name: Union[str, bytes],
     out_types: Sequence[ir.Type],
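Note: a short sketch of how these helpers compose (requires an active MLIR context; `dyn_b` is assumed to be an i32 ir.Value carrying a dynamic leading dimension).

def result_shape_example(dyn_b: ir.Value) -> ir.Value:
  # Builds the 1-D i32 shape tensor <dyn_b, 5, 5> that custom_call's
  # result_shapes parameter expects for a dynamically shaped result.
  return shape_tensor((dyn_b, 5, 5))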
@@ -42,12 +97,12 @@ def custom_call(
       match the number of the results. They are appended to the list
       of operands.
   """
-  i32_type = ir.IntegerType.get_signless(32)
   attributes = dict(
       call_target_name=ir.StringAttr.get(call_target_name),
       has_side_effect=ir.BoolAttr.get(has_side_effect),
       backend_config=ir.StringAttr.get(backend_config),
-      api_version=ir.IntegerAttr.get(i32_type, api_version),
+      api_version=ir.IntegerAttr.get(
+          ir.IntegerType.get_signless(32), api_version),
       called_computations=ir.ArrayAttr.get([]),
       output_operand_aliases=ir.ArrayAttr.get([
           hlo.OutputOperandAlias.get(
jaxlib/lapack.py (131 lines changed)
@@ -19,9 +19,13 @@ import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.stablehlo as hlo

 import numpy as np
+from typing import Tuple
 from jaxlib import xla_client

-from .hlo_helpers import custom_call
+from .hlo_helpers import (
+    custom_call, ir_constant_u8, ir_constant_i32,
+    shape_tensor
+)
 from .cpu import _lapack

 for _name, _value in _lapack.registrations().items():
@@ -477,85 +481,126 @@ def syevd_hlo(dtype, a: ir.Value, batch_size: ir.Value,
   return out[:3]


-# # geev: Nonsymmetric eigendecomposition
+# # geev: Nonsymmetric eigendecomposition (eig)

-def geev_hlo(dtype, a, jobvl=True, jobvr=True):
+def geev_hlo(dtype, input, *,
+             input_shape_vals: Tuple[ir.Value, ...],  # input.shape as ir.Values
+             jobvl=True, jobvr=True):
+  # input_shape_vals are used when input has dynamic shapes.
   _initialize()
-  dims = ir.RankedTensorType(a.type).shape
-  assert len(dims) >= 2
-  m, n = dims[-2:]
-  assert m == n
-  batch_dims = tuple(dims[:-2])
+  input_shape = ir.RankedTensorType(input.type).shape
+  assert len(input_shape) >= 2
+  n = input_shape[-1]
+  n_val: ir.Value = input_shape_vals[-1]
+  batch_dims = tuple(input_shape[:-2])
+  batch_dims_vals = input_shape_vals[:-2]
   num_bd = len(batch_dims)
-  b = 1
-  for d in batch_dims:
-    b *= d
-
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

   jobvl_c = ord('V' if jobvl else 'N')
   jobvr_c = ord('V' if jobvr else 'N')

-  i32_type = ir.IntegerType.get_signless(32)
+  f32_type = ir.F32Type.get()
+  f64_type = ir.F64Type.get()
+  c64_type = ir.ComplexType.get(ir.F32Type.get())
+  c128_type = ir.ComplexType.get(ir.F64Type.get())
+
+  if n == ir.ShapedType.get_dynamic_size():
+    two_n = ir.ShapedType.get_dynamic_size()
+  else:
+    two_n = n + n
   if dtype == np.float32:
     fn = b"lapack_sgeev"
     real = True
-    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
-    workspaces = [ir.RankedTensorType.get([n, n], ir.F32Type.get()),
-                  ir.RankedTensorType.get([n, n], ir.F32Type.get()),
-                  ir.RankedTensorType.get([n, n], ir.F32Type.get())]
+    eigvecs_type = c64_type
+    workspace_types = [ir.RankedTensorType.get([n, n], f32_type)] * 3
+    workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
     workspace_layouts = [[0, 1]] * 3
-    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get()),
-               ir.RankedTensorType.get(batch_dims + (n,), ir.F32Type.get())]
+    eigval_types = [
+        ir.RankedTensorType.get(batch_dims + (n,), f32_type)] * 2
+    eigval_result_shapes = [
+        shape_tensor(batch_dims_vals + (n_val,))] * 2
     eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
   elif dtype == np.float64:
     fn = b"lapack_dgeev"
     real = True
-    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
-    workspaces = [ir.RankedTensorType.get([n, n], ir.F64Type.get()),
-                  ir.RankedTensorType.get([n, n], ir.F64Type.get()),
-                  ir.RankedTensorType.get([n, n], ir.F64Type.get())]
+    eigvecs_type = c128_type
+    workspace_types = [ir.RankedTensorType.get([n, n], f64_type)] * 3
+    workspace_result_shapes = [shape_tensor((n_val, n_val))] * 3
     workspace_layouts = [[0, 1]] * 3
-    eigvals = [ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get()),
-               ir.RankedTensorType.get(batch_dims + (n,), ir.F64Type.get())]
+    eigval_types = [
+        ir.RankedTensorType.get(batch_dims + (n,), f64_type)] * 2
+    eigval_result_shapes = [
+        shape_tensor(batch_dims_vals + (n_val,))] * 2
     eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2
   elif dtype == np.complex64:
     fn = b"lapack_cgeev"
     real = False
-    eigvecs_type = ir.ComplexType.get(ir.F32Type.get())
-    workspaces = [ir.RankedTensorType.get([n, n],
-                                          ir.ComplexType.get(ir.F32Type.get())),
-                  ir.RankedTensorType.get([2 * n], ir.F32Type.get())]
+    eigvecs_type = c64_type
+    workspace_types = [
+        ir.RankedTensorType.get([n, n], c64_type),
+        ir.RankedTensorType.get([two_n], f32_type)]
+    workspace_result_shapes = [
+        shape_tensor((n_val, n_val)),
+        shape_tensor((hlo.AddOp(n_val, n_val).result,))]
     workspace_layouts = [[0, 1], [0]]
-    eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
-                                       ir.ComplexType.get(ir.F32Type.get()))]
+    eigval_types = [
+        ir.RankedTensorType.get(batch_dims + (n,), c64_type)]
+    eigval_result_shapes = [shape_tensor(batch_dims_vals + (n_val,))]
     eigvals_layouts = [tuple(range(num_bd, -1, -1))]
   elif dtype == np.complex128:
     fn = b"lapack_zgeev"
     real = False
-    eigvecs_type = ir.ComplexType.get(ir.F64Type.get())
-    workspaces = [ir.RankedTensorType.get([n, n],
-                                          ir.ComplexType.get(ir.F64Type.get())),
-                  ir.RankedTensorType.get([2 * n], ir.F64Type.get())]
+    eigvecs_type = c128_type
+    workspace_types = [
+        ir.RankedTensorType.get([n, n], c128_type),
+        ir.RankedTensorType.get([two_n], f64_type)]
+    workspace_result_shapes = [
+        shape_tensor((n_val, n_val)),
+        shape_tensor((hlo.AddOp(n_val, n_val).result,))]
     workspace_layouts = [[0, 1], [0]]
-    eigvals = [ir.RankedTensorType.get(batch_dims + (n,),
-                                       ir.ComplexType.get(ir.F64Type.get()))]
+    eigval_types = [
+        ir.RankedTensorType.get(batch_dims + (n,), c128_type)]
+    eigval_result_shapes = [
+        shape_tensor(batch_dims_vals + (n_val,))]
     eigvals_layouts = [tuple(range(num_bd, -1, -1))]
   else:
     raise NotImplementedError(f"Unsupported dtype {dtype}")

+  i32_type = ir.IntegerType.get_signless(32)
   scalar_layout = []
   info_layout = tuple(range(num_bd - 1, -1, -1))

+  batch_size_val = ir_constant_i32(1)
+  for b_v in batch_dims_vals:
+    batch_size_val = hlo.MulOp(batch_size_val, b_v).result
+
+  result_types = (
+      workspace_types + eigval_types + [
+      ir.RankedTensorType.get(input_shape, eigvecs_type),
+      ir.RankedTensorType.get(input_shape, eigvecs_type),
+      ir.RankedTensorType.get(batch_dims, i32_type),
+  ])
+  if any(a == ir.ShapedType.get_dynamic_size() for a in input_shape):
+    result_shapes = workspace_result_shapes + eigval_result_shapes + [
+        shape_tensor(input_shape_vals),
+        shape_tensor(input_shape_vals),
+        shape_tensor(batch_dims_vals),
+    ]
+  else:
+    result_shapes = None
   out = custom_call(
       fn,
-      workspaces + eigvals + [
-          ir.RankedTensorType.get(dims, eigvecs_type),
-          ir.RankedTensorType.get(dims, eigvecs_type),
-          ir.RankedTensorType.get(batch_dims, i32_type),
-      ],
-      [_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a],
+      result_types,
+      [batch_size_val, n_val,
+       ir_constant_u8(jobvl_c),
+       ir_constant_u8(jobvr_c),
+       input],
       operand_layouts=[scalar_layout] * 4 + [layout],
       result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 +
-                      [info_layout])
+                      [info_layout]),
+      result_shapes=result_shapes,
   )
   if real:
     return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
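Note: a small sketch of the dynamic-extent convention geev_hlo relies on above. MLIR reports unknown extents with the ShapedType sentinel, so static arithmetic such as two_n = n + n must be guarded (illustrative helper, mirroring the hunk; not code from this commit).

import jaxlib.mlir.ir as ir

def doubled_extent(n: int) -> int:
  # `n` comes from ir.RankedTensorType(...).shape; a dynamic dimension is
  # reported as the sentinel, which must be propagated rather than doubled.
  if n == ir.ShapedType.get_dynamic_size():
    return ir.ShapedType.get_dynamic_size()
  return n + n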