mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Add static constraint checking to the computation of dim vars
Previously we had one function `shape_poly.unify_avals_with_args` that was solving the dimension variables and was also used for generating the code to compute them. Now we separate the solving part, which is now using just symbolic expressions (`shape_poly.solve_dim_vars`), from the code generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`). We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or `dimexpr1 >= dimexpr2`, under which the solution for the dimension variables is valid. For now we implement the static checking of the shape constraints, e.g., when the dimension expressions are constant or TF EagerTensor. We do not yet have compile-time checking of the constraints. This matches the previous behavior. However, the code is now ready for implementing compile-time checking of the constraints that cannot be checked statically.
This commit is contained in:
parent
acfeb9bb13
commit
9ad8c3b9f1
@ -82,11 +82,14 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
|
||||
def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
|
||||
) -> ir.RankedTensorType:
|
||||
int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
|
||||
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
||||
def lower_dim(d):
|
||||
if type(d) is int:
|
||||
return ir_constant(np.array([d], np.int32))
|
||||
else:
|
||||
return hlo.ReshapeOp(int1d, hlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
|
||||
if d.type != i32_type:
|
||||
d = hlo.ConvertOp(i32_type, d)
|
||||
return hlo.ReshapeOp(int1d, d)
|
||||
ds = map(lower_dim, sizes)
|
||||
if not ds:
|
||||
return ir_constant(np.array([], np.int32))
|
||||
|
@ -2291,7 +2291,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
|
||||
step = 1
|
||||
elif stop is not None and step is None:
|
||||
step = 1
|
||||
return _arange_dynamic(start, stop, step, dtype or int_)
|
||||
return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64))
|
||||
if dtype is None:
|
||||
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
|
@ -564,9 +564,8 @@ class GraphSerializationImpl(SerializationImpl):
|
||||
map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))
|
||||
dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
|
||||
dim_values, _ = _interpret_fun_jax(
|
||||
partial(shape_poly.unify_avals_with_args, self.args_avals_flat, dim_vars,
|
||||
use_static_dimension_size=False,
|
||||
args_kwargs_tree=self.in_tree),
|
||||
partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
||||
self.args_avals_flat, args_kwargs_tree=self.in_tree),
|
||||
self.args_flat_tf, self.args_avals_flat, self.name_stack)
|
||||
_thread_local_state.shape_env = zip(dim_vars, dim_values)
|
||||
|
||||
|
@ -483,12 +483,11 @@ def _compute_dim_args(
|
||||
the values of the dimension variables, in the sorted order of the
|
||||
dimension variables.
|
||||
"""
|
||||
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
|
||||
dim_values = mlir.lower_fun(
|
||||
functools.partial(shape_poly.unify_avals_with_args, args_avals_flat, dim_vars,
|
||||
use_static_dimension_size=False,
|
||||
args_kwargs_tree=args_kwargs_tree),
|
||||
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
||||
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
|
||||
multiple_results=True)(ctx, *array_args)
|
||||
|
||||
res = []
|
||||
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
|
||||
if dim_arg.type != dim_arg_type:
|
||||
@ -752,16 +751,42 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
|
||||
call_exported_p = core.Primitive("call_exported")
|
||||
call_exported_p.multiple_results = True
|
||||
|
||||
@util.cache()
|
||||
def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
|
||||
exported: Exported) -> Tuple[core.AbstractValue, ...]:
|
||||
exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
|
||||
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
|
||||
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
||||
exported_dim_values = shape_poly.unify_avals_with_args(
|
||||
exported.in_avals, exported_dim_vars, *in_avals, # type: ignore
|
||||
use_static_dimension_size=True,
|
||||
args_kwargs_tree=exported.in_tree)
|
||||
# Check that the expected shapes match the actual ones
|
||||
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
|
||||
def pp_arg_dim(dim_idx: Optional[int]) -> str:
|
||||
return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
|
||||
arg_idx, dim_idx)
|
||||
if len(exp_aval.shape) != len(actual_aval.shape):
|
||||
raise ValueError(
|
||||
f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
|
||||
f"and called with {actual_aval.shape}")
|
||||
if exp_aval.dtype != actual_aval.dtype:
|
||||
raise ValueError(
|
||||
f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
|
||||
f"and called with {actual_aval.dtype}")
|
||||
for dim_idx, aval_d in enumerate(exp_aval.shape):
|
||||
# If the exp_aval has a constant dimension then the actual argument must have
|
||||
# a matching constant dimension.
|
||||
if core.is_constant_dim(aval_d):
|
||||
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
|
||||
aval_d != actual_aval.shape[dim_idx]):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for {pp_arg_dim(dim_idx)} (expected constant): "
|
||||
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
|
||||
|
||||
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
||||
solution, shape_constraints, known_dim_vars = shape_poly.solve_dim_vars(
|
||||
exported.in_avals, args_kwargs_tree=exported.in_tree)
|
||||
known_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
||||
for (vname, arg_idx, dim_idx) in known_dim_vars}
|
||||
shape_constraints.check(known_env)
|
||||
exported_dim_values = [solution[var].evaluate(known_env)
|
||||
for var in exported_dim_vars]
|
||||
return tuple(
|
||||
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
|
||||
*exported_dim_values),
|
||||
|
@ -32,6 +32,7 @@ jax2tf.convert docstring, and the
|
||||
"""
|
||||
import collections
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
import functools
|
||||
import itertools
|
||||
import io
|
||||
@ -53,6 +54,7 @@ from jax._src import dtypes
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.typing import DimSize, Shape
|
||||
|
||||
|
||||
@ -640,7 +642,8 @@ class _DimExpr():
|
||||
|
||||
def evaluate(self, env: DimVarEnv):
|
||||
# Evaluates as a value of dtype=core.dim_value_dtype()
|
||||
terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) for mon, coeff in self.monomials()]
|
||||
terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff))
|
||||
for mon, coeff in self.monomials()]
|
||||
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
@staticmethod
|
||||
@ -1098,21 +1101,80 @@ def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
|
||||
dim_vars = dim_vars.union(d.get_vars())
|
||||
return sorted(tuple(dim_vars))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DimEquation:
|
||||
# Represents arg.shape[dim_idx] == dim_expr
|
||||
arg: jax.Array
|
||||
dim_idx: int
|
||||
dim_expr: _DimExpr
|
||||
debug_arg_str: Callable[[], str] # A pretty-printer for a descriptor for `arg`
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ShapeConstraint:
|
||||
class Comparator(Enum):
|
||||
EQ = 1
|
||||
GEQ = 2
|
||||
|
||||
comp: Comparator
|
||||
left: DimSize
|
||||
right: DimSize
|
||||
# make_err_msg is invoked with (left_int, right_int) if the constraint fails.
|
||||
make_err_msg: Callable[[int, int], str]
|
||||
|
||||
def check(self, shapeenv: DimVarEnv) -> None:
|
||||
"""Evaluates a constraint statically and raises an error if fails."""
|
||||
def eval_operand(o: DimSize) -> Union[int, jax.Array]:
|
||||
if core.is_constant_dim(o): return op.index(o)
|
||||
return o.evaluate(shapeenv) # type: ignore
|
||||
try:
|
||||
left1, right1 = eval_operand(self.left), eval_operand(self.right)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
left_int, right_int = _is_known_constant(left1), _is_known_constant(right1)
|
||||
if left_int is not None and right_int is not None:
|
||||
if self.comp == ShapeConstraint.Comparator.EQ:
|
||||
if not (left_int == right_int):
|
||||
raise ValueError(self.make_err_msg(left_int, right_int))
|
||||
elif self.comp == ShapeConstraint.Comparator.GEQ:
|
||||
if not (left_int >= right_int):
|
||||
raise ValueError(self.make_err_msg(left_int, right_int))
|
||||
else: assert False
|
||||
else:
|
||||
return None # TODO: evaluate constraint dynamically
|
||||
|
||||
def __str__(self):
|
||||
return (f"{self.dim_expr} == {self.debug_arg_str()}"
|
||||
f".shape[{self.dim_idx}] (statically {self.arg.shape[self.dim_idx]})")
|
||||
return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}"
|
||||
f" ({self.make_err_msg(self.left, self.right)})")
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
class ShapeConstraints:
|
||||
def __init__(self):
|
||||
self.constraints: Set[ShapeConstraint] = set() # map DimConstraint to an integer >= 0
|
||||
|
||||
|
||||
def add_constraint(self,
|
||||
comp: ShapeConstraint.Comparator,
|
||||
left: DimSize, right: DimSize,
|
||||
make_err_msg: Callable[[int, int], str]):
|
||||
# Try to evaluate it statically
|
||||
c = ShapeConstraint(comp, left, right, make_err_msg)
|
||||
self.constraints.add(c)
|
||||
|
||||
def check(self, shapeenv: DimVarEnv) -> None:
|
||||
for constraint in self.constraints:
|
||||
constraint.check(shapeenv)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DimEquation:
|
||||
# Represents dim_expr == dim_value, where `dim_expr` contain unknown dimension
|
||||
# variables, in terms of `dim_value`.
|
||||
dim_expr: _DimExpr
|
||||
dim_value: _DimExpr
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.dim_expr} == {self.dim_value}"
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
|
||||
# String description of args or kwargs, assuming the path is in a tree for
|
||||
# the tuple (args, kwargs)
|
||||
# String description of `args` or `kwargs`, assuming the path for a tree for
|
||||
# the tuple `(args, kwargs)`.
|
||||
if path[0] == tree_util.SequenceKey(0):
|
||||
return f"args{tree_util.keystr(path[1:])}"
|
||||
elif path[0] == tree_util.SequenceKey(1):
|
||||
@ -1120,92 +1182,104 @@ def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
|
||||
else:
|
||||
assert False
|
||||
|
||||
def unify_avals_with_args(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
dim_vars: Sequence[str],
|
||||
*args: jax.Array,
|
||||
use_static_dimension_size: bool,
|
||||
def pretty_print_dimension_descriptor(
|
||||
args_kwargs_tree: tree_util.PyTreeDef,
|
||||
) -> Sequence[jax.Array]:
|
||||
"""Computes values of dimension variables to unify avals with actual arguments.
|
||||
flat_arg_idx: int, dim_idx: Optional[int]) -> str:
|
||||
args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path(
|
||||
args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves))
|
||||
arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0])
|
||||
if dim_idx is not None:
|
||||
arg_str += f".shape[{dim_idx}]"
|
||||
return arg_str
|
||||
|
||||
Computes values for dimension variables for which the shapes in `args_avals`
|
||||
(abstract values for a function's parameters) match the shapes of `args` (the
|
||||
actual arguments). This is done by forming equations
|
||||
between the symbolic expressions from `args_avals` and the actual dimension
|
||||
sizes of the actual arguments, and then solving for the dimension variables.
|
||||
@util.cache()
|
||||
def solve_dim_vars(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
args_kwargs_tree: tree_util.PyTreeDef,
|
||||
) -> Tuple[DimVarEnv, ShapeConstraints, Sequence[Tuple[str, int, int]]]:
|
||||
"""Solves dimension variables in a called function's avals in terms of actual argument shapes.
|
||||
|
||||
Not all equations are solvable. For now, the linear uni-variate equations
|
||||
are solved first, then the solved variables are used to simplify the
|
||||
remaining equations to linear uni-variate equations, and the process continues
|
||||
For example, given:
|
||||
|
||||
args_avals = [ShapedArray((3, a, a + b), f32)]
|
||||
|
||||
we introduce fresh "known" dimension variables to represent the actual dimension
|
||||
size of actual arguments for each non-constant dimension. Each known variable
|
||||
has a name, an arg_idx, and a dim_idx, e.g.:
|
||||
|
||||
known_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
|
||||
|
||||
and then we express the solution for the unknown dimension variables {a, b}
|
||||
as symbolic expressions in terms of the known variables:
|
||||
|
||||
dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])
|
||||
|
||||
Not all equations are solvable. For now, we solve first the linear uni-variate
|
||||
equations, then the solved variables are used to simplify the remaining
|
||||
equations to linear uni-variate equations, and the process continues
|
||||
until all dimension variables are solved.
|
||||
|
||||
Args:
|
||||
args_avals: the abstract values of the `args`, with shapes that may
|
||||
include dimension variables. A flat sequence.
|
||||
dim_vars: the dimension variables that occur in `args_avals`. The only
|
||||
reason we need these is to ensure that the result of this function is a
|
||||
flat list of jax.Array in the same order.
|
||||
args: the actual function arguments, as jax.Array or any value with `.shape`
|
||||
and `.dtype` attributes if `use_static_dimension_size`.
|
||||
use_static_dimension_size: if `True` then it forms the equations using the
|
||||
static shapes of `args`. This is useful, e.g., when we want to compute
|
||||
the dimension variables statically. If `False` then it forms the
|
||||
equations using the dynamic shapes of `args`, e.g., using
|
||||
`stablehlo.GetDimensionSizeOp` for native serialization or `tf.shape`
|
||||
for TF graph serialization. This is useful when we want to generate
|
||||
code to compute the dimension variables at compilation-time.
|
||||
include unknown dimension variables.
|
||||
args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` from
|
||||
which the flat sequence `args_avals` is extracted. Used for referencing args and
|
||||
kwargs in error messages.
|
||||
which the flat sequence `args_avals` is extracted. Used for describing
|
||||
args and kwargs in known variable names and in error messages.
|
||||
|
||||
Returns: the values of `dim_vars` in the same order.
|
||||
Returns: a 3-tuple with: (a) the solution for the unknown dimension variables
|
||||
(b) a list of constraints that must be satisfied for the solution to be a
|
||||
valid one, and (c) and the list of known variables that may appear in
|
||||
the solution and the constraints.
|
||||
|
||||
Raises ValueError if it cannot solve for the `dim_vars`.
|
||||
Raises ValueError if it cannot solve some dimension variable.
|
||||
"""
|
||||
dim_equations: List[DimEquation] = []
|
||||
def debug_arg_str(flat_arg_idx: int) -> str:
|
||||
# Debug descriptor of an argument.
|
||||
args_avals_tree = args_kwargs_tree.unflatten(args_avals)
|
||||
args_avals_with_paths, _ = tree_util.tree_flatten_with_path(args_avals_tree)
|
||||
return args_kwargs_path_to_str(args_avals_with_paths[flat_arg_idx][0])
|
||||
|
||||
for arg_idx, (aval, arg) in enumerate(zip(args_avals, args)):
|
||||
if len(aval.shape) != len(arg.shape):
|
||||
raise ValueError(
|
||||
f"Rank mismatch for {debug_arg_str(arg_idx)}: expected {aval.shape} "
|
||||
f"and called with {arg.shape}")
|
||||
continue
|
||||
if aval.dtype != arg.dtype:
|
||||
raise ValueError(
|
||||
f"Dtype mismatch for {debug_arg_str(arg_idx)}: expected {aval.dtype} "
|
||||
f"and called with {arg.dtype}")
|
||||
dim_equations: List[_DimEquation] = []
|
||||
known_dimension_vars: List[Tuple[str, int, int]] = []
|
||||
for arg_idx, aval in enumerate(args_avals):
|
||||
for dim_idx, aval_d in enumerate(aval.shape):
|
||||
# If the aval has a constant dimension then the actual argument must have
|
||||
# a matching constant dimension.
|
||||
if not is_poly_dim(aval_d):
|
||||
if _is_known_constant(arg.shape[dim_idx]) is None or aval_d != arg.shape[dim_idx]:
|
||||
raise ValueError(
|
||||
f"Shape mismatch for {debug_arg_str(arg_idx)} in dimension {dim_idx}: "
|
||||
f"expected {aval.shape} and called with {arg.shape}")
|
||||
else:
|
||||
if is_poly_dim(aval_d):
|
||||
known_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
|
||||
arg_idx, dim_idx)
|
||||
known_dimension_vars.append((known_dim_var, arg_idx, dim_idx))
|
||||
dim_equations.append(
|
||||
DimEquation(arg=arg,
|
||||
dim_idx=dim_idx, dim_expr=_ensure_poly(aval_d, "unify_avals_with_args"),
|
||||
debug_arg_str=functools.partial(debug_arg_str, arg_idx)))
|
||||
_DimEquation(dim_expr=_ensure_poly(aval_d, "solve_dim_vars"),
|
||||
dim_value=_DimExpr.from_var(known_dim_var)))
|
||||
|
||||
dim_env = _solve_dim_equations(dim_equations,
|
||||
use_static_dimension_size=use_static_dimension_size)
|
||||
dim_values = tuple(dim_env[dv] for dv in dim_vars)
|
||||
return dim_values
|
||||
solution, shape_constraints = _solve_dim_equations(dim_equations)
|
||||
return solution, shape_constraints, known_dimension_vars
|
||||
|
||||
def compute_dim_vars_from_arg_shapes(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
*actual_args: jax.Array,
|
||||
args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]:
|
||||
"""Computes values of dimension variables to unify args_avals with actual arguments.
|
||||
|
||||
Like `solve_dim_vars` except that here we express the solution as
|
||||
JAX arrays that reference the `actual_args`. This function can be used to
|
||||
generate the code for computing the dimension variables.
|
||||
|
||||
Returns: the values of the dimension variables, in the order determined by
|
||||
`all_dim_vars(args_avals)`.
|
||||
"""
|
||||
dim_vars = all_dim_vars(args_avals)
|
||||
solution, shape_constraints, known_dim_vars = solve_dim_vars(
|
||||
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
|
||||
|
||||
# Replace the synthetic vars with the dynamic shape of the actual arg
|
||||
known_env = {vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in known_dim_vars}
|
||||
dim_values = [solution[var].evaluate(known_env) for var in dim_vars]
|
||||
shape_constraints.check(known_env)
|
||||
return tuple(dim_values)
|
||||
|
||||
|
||||
def _solve_dim_equations(eqns: List[DimEquation],
|
||||
use_static_dimension_size: bool) -> DimVarEnv:
|
||||
# Returns a shape environment if it can solve all dimension variables.
|
||||
# Raises an exception if it cannot.
|
||||
def _solve_dim_equations(
|
||||
eqns: List[_DimEquation]
|
||||
) -> Tuple[DimVarEnv, ShapeConstraints]:
|
||||
# Returns a shape environment and the shape constraints if it can solve all
|
||||
# dimension variables. Raises an exception if it cannot.
|
||||
shapeenv: DimVarEnv = {}
|
||||
|
||||
shape_constraints = ShapeConstraints()
|
||||
def _shapeenv_to_str() -> str:
|
||||
if shapeenv:
|
||||
return (" Partial solution: " +
|
||||
@ -1213,27 +1287,17 @@ def _solve_dim_equations(eqns: List[DimEquation],
|
||||
else:
|
||||
return ""
|
||||
|
||||
def process_one_eqn(eqn: DimEquation) -> bool:
|
||||
def process_one_eqn(eqn: _DimEquation) -> bool:
|
||||
# We start with a DimEquation of the form `dim_expr = dim_value`
|
||||
# Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear
|
||||
# uni-variate equation). Returns `False` if this rewrite fails.
|
||||
# Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to
|
||||
# `shapeenv` and return `True`.
|
||||
#
|
||||
# TODO: does not yet fully handle the cases when `dim_value` is not
|
||||
# divisible by `factor`, or when the value is not greater or equal to 1.
|
||||
|
||||
# Invariant:
|
||||
# var * factor_var + remaining_monomials_from_dim_expr = dim_value
|
||||
var, factor_var = None, None
|
||||
if use_static_dimension_size:
|
||||
dim_value = eqn.arg.shape[eqn.dim_idx]
|
||||
if _is_known_constant(dim_value):
|
||||
dim_value = core.dim_constant(dim_value)
|
||||
else:
|
||||
# We use the dimension_size_p primitive when we want to lower to code
|
||||
# that fetches the dimension size at compile-time.
|
||||
dim_value = dimension_size_p.bind(eqn.arg, dimension=eqn.dim_idx)
|
||||
dim_value = eqn.dim_value
|
||||
|
||||
for mon, factor in eqn.dim_expr.monomials():
|
||||
# Perhaps we can already evaluate this monomial (all vars solved)
|
||||
@ -1253,26 +1317,22 @@ def _solve_dim_equations(eqns: List[DimEquation],
|
||||
|
||||
if var is not None:
|
||||
if factor_var == 1:
|
||||
var_value, var_remainder = dim_value, core.dim_constant(0)
|
||||
var_value = dim_value
|
||||
else:
|
||||
var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.EQ, var_remainder, 0,
|
||||
make_err_msg=lambda rem_int, _: (
|
||||
f"Dimension variable '{var}' must have integer value >= 1. "
|
||||
f"Non-zero remainder {rem_int} for factor {factor_var} when solving "
|
||||
f"{eqn}.{_shapeenv_to_str()}"))
|
||||
|
||||
# Check that the division is even. Works only in TF eager mode.
|
||||
# TODO: check dynamically if not possible statically
|
||||
# Or maybe we should abort right away if the remainder is not 0?
|
||||
var_remainder_int = _is_known_constant(var_remainder)
|
||||
if var_remainder_int is not None and var_remainder_int != 0:
|
||||
msg = (f"Dimension variable '{var}' must have integer value >= 1. " # type: ignore
|
||||
f"Found value {int(_is_known_constant(dim_value)) / factor_var} when solving " # type: ignore
|
||||
f"{eqn}.{_shapeenv_to_str()}")
|
||||
raise ValueError(msg)
|
||||
var_value_int = _is_known_constant(var_value)
|
||||
if var_value_int is not None and var_value_int <= 0:
|
||||
# TODO: check dynamically if not possible statically
|
||||
msg = (f"Dimension variable '{var}' must have integer value >= 1. "
|
||||
f"Found value {int(var_value_int)} when solving "
|
||||
f"{eqn}.{_shapeenv_to_str()}")
|
||||
raise ValueError(msg)
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.GEQ, var_value, 1,
|
||||
make_err_msg=lambda var_int, _: (
|
||||
f"Dimension variable '{var}' must have integer value >= 1. "
|
||||
f"Found {var_int} when "
|
||||
f"solving {eqn}.{_shapeenv_to_str()}"))
|
||||
|
||||
if not isinstance(var_value, _DimExpr):
|
||||
assert var_value.dtype == core.dim_value_dtype()
|
||||
@ -1280,20 +1340,18 @@ def _solve_dim_equations(eqns: List[DimEquation],
|
||||
return True
|
||||
else:
|
||||
# All variables are resolved for this equation
|
||||
dim_value_int = _is_known_constant(dim_value)
|
||||
if dim_value_int is not None and dim_value_int != 0:
|
||||
# TODO: check dynamically if not possible statically
|
||||
err_msg = (
|
||||
"Found inconsistency when solving "
|
||||
f"{eqn}.{_shapeenv_to_str()}")
|
||||
raise ValueError(err_msg)
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.EQ, eqn.dim_value,
|
||||
eqn.dim_expr.evaluate(shapeenv),
|
||||
make_err_msg=lambda val1, val2: (
|
||||
f"Found inconsistency {val1} != {val2} when solving {eqn}.{_shapeenv_to_str()}"))
|
||||
return True
|
||||
|
||||
while True:
|
||||
nr_eqns = len(eqns)
|
||||
eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
|
||||
if not eqns:
|
||||
return shapeenv # SUCCESS
|
||||
return shapeenv, shape_constraints # SUCCESS
|
||||
elif len(eqns) >= nr_eqns:
|
||||
break
|
||||
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
@ -69,13 +68,25 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
def test_poly_export_only(self):
|
||||
a = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f(a):
|
||||
return jnp.concatenate([a, a], axis=0)
|
||||
def f(a, b): # a: f32[2w,h] b: f32[w,h]
|
||||
return jnp.concatenate([a, b], axis=0)
|
||||
|
||||
exp = jax_export.export(f)(
|
||||
jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
|
||||
jax_export.poly_spec(a.shape, a.dtype, "(w, h)"))
|
||||
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
|
||||
self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
|
||||
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
|
||||
|
||||
def test_poly_pytree_export_only(self):
|
||||
a = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f(a0, a1, *, ak):
|
||||
return jnp.concatenate([a0, a1, ak], axis=0)
|
||||
|
||||
a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)")
|
||||
exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
|
||||
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
|
||||
self.assertEqual("(2*w, h)", str(exp.out_avals[0].shape))
|
||||
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
|
||||
|
||||
def test_basic(self):
|
||||
f = jnp.sin
|
||||
@ -142,11 +153,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp_f = jax_export.export(f)(f32_4, b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for args\[0\] in dimension 0"):
|
||||
r"Shape mismatch for args\[0\].shape\[0\]"):
|
||||
jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for kwargs\['b'\] in dimension 0"):
|
||||
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
|
||||
jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
@ -214,50 +225,65 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
jax_export.call_exported(exp_f2)(a))
|
||||
|
||||
# An inner function is exported with polymorphic shapes inner_poly_spec, and
|
||||
# is called from an outer function, that is exported with outer_poly_spec.
|
||||
# is called from an outer function, which is exported with outer_poly_spec.
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"inner={inner_poly_spec}_outer={outer_poly_spec}",
|
||||
inner_poly_spec=inner_poly_spec, outer_poly_spec=outer_poly_spec,
|
||||
expect_error=expect_error)
|
||||
for inner_poly_spec, outer_poly_spec, expect_error in (
|
||||
("3,a,a+b", "3,4,12", None),
|
||||
("3,a,a+b", "3,4,c", None),
|
||||
("3,a,a+b", "3,c,c", r"Dimension variable.*b.*must have.* >= 1. Found value 0"),
|
||||
("3,a,a+b", "c,4,12", r"Shape mismatch for args\[0\] in dimension 0"),
|
||||
("3,a,a+b", "3,c+4,12", None), # TODO: This should be an error, c = 0
|
||||
("3,4,3*a", "3,4,12", None),
|
||||
("3,4,5*a", "3,4,12", r"Dimension variable 'a' must have integer value >= 1. Found value 2.4"),
|
||||
# ("3,a,a", "3,a,a", None), # TODO: wrong error. It should be shape mismatch
|
||||
# ("3,4,5*a", "3,4,c", None), # TODO: wrong error. It should be "not divisible by 5"
|
||||
dict(testcase_name=f"inner={d['inner_poly_spec']}_outer={d['outer_poly_spec']}", # type: ignore
|
||||
**d) # type: ignore
|
||||
for d in (
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
|
||||
expect_error=(
|
||||
r"Dimension variable 'b' must have integer value >= 1. "
|
||||
r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: This should be an error, c = 0
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
|
||||
expect_error=(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")),
|
||||
# dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12
|
||||
# dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: this should be a dynamic error
|
||||
dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
|
||||
expect_error=r"Rank mismatch for args\[0\]"),
|
||||
dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
|
||||
expect_error=r"Dtype mismatch for args\[0\]"),
|
||||
))
|
||||
def test_poly(self, inner_poly_spec="3,a,a+b",
|
||||
outer_poly_spec="3,4,12", expect_error=None):
|
||||
def test_poly(self, inner_poly_spec="3,a,a+b", inner_x_shape=(3, 4, 6),
|
||||
inner_x_dtype=np.float32,
|
||||
outer_poly_spec="3,c+4,12", outer_x_shape=(3, 4, 12),
|
||||
expect_error=None):
|
||||
# Polymorphic export called with static or polymorphic shapes
|
||||
def inner(x): # x: export_poly_spec
|
||||
return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
|
||||
def inner(x): # x: inner_poly_spec
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6)) # x1 : f32[3,4,6]
|
||||
exp1 = jax_export.export(inner)(jax_export.poly_spec(x1.shape, x1.dtype, inner_poly_spec))
|
||||
inner_x = np.arange(np.prod(inner_x_shape),
|
||||
dtype=inner_x_dtype).reshape(inner_x_shape) # inner_x : f32[3,4,6]
|
||||
inner_exp = jax_export.export(inner)(
|
||||
jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
|
||||
|
||||
x2 = np.concatenate([x1, x1], axis=2) # x2: f32[3,4,12]
|
||||
def outer(x): # x: call_poly_spec
|
||||
outer_x = np.arange(np.prod(outer_x_shape),
|
||||
dtype=np.float32).reshape(outer_x_shape) # outer_x : f32[3,4,12]
|
||||
def outer(x): # x: outer_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
# result of the call_exported.
|
||||
return jax_export.call_exported(exp1)(x) + inner(x)
|
||||
return jax_export.call_exported(inner_exp)(x) + inner(x)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(ValueError, expect_error))
|
||||
|
||||
# Call it after exporting again, with polymorphic shapes
|
||||
exp2 = jax_export.export(outer)(
|
||||
jax_export.poly_spec(x2.shape, x2.dtype, outer_poly_spec))
|
||||
outer_exp = jax_export.export(outer)(
|
||||
jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
|
||||
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
|
||||
# until we create the python bindings to invoke shape refinement.
|
||||
if jax2tf is not None:
|
||||
res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()
|
||||
res2 = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
|
||||
# res2 = jax_export.call_exported(exp2)(x2)
|
||||
self.assertAllClose(2. * inner(x2), res2)
|
||||
self.assertAllClose(2. * inner(outer_x), res2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -715,7 +715,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
("0_b_1_f32", lambda b: (0, b, 1, np.float32))
|
||||
]
|
||||
])
|
||||
def test_arange(self, make_args=lambda b: (0, 0, b)):
|
||||
def test_arange(self, make_args=lambda b: (0, -b, 2, None)):
|
||||
def f_jax(x): # x: i32[b]
|
||||
return x[0] + jnp.arange(*(make_args(x.shape[0])))
|
||||
x = np.ones((3,), dtype=np.int32)
|
||||
@ -850,9 +850,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
|
||||
dim_vars = shape_poly.all_dim_vars(avals)
|
||||
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
|
||||
partial(shape_poly.unify_avals_with_args, avals, dim_vars,
|
||||
use_static_dimension_size=False,
|
||||
args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
|
||||
partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
||||
avals, args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
|
||||
args_tf, avals, "")
|
||||
if expected_avals is not None:
|
||||
self.assertEqual(expected_avals, avals)
|
||||
@ -988,7 +987,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Found inconsistency when solving.*"):
|
||||
"Found inconsistency 3 != 2 when solving.*"):
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3)],
|
||||
polymorphic_shapes=["(a, a)"],
|
||||
@ -997,7 +996,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# Same error across multiple arguments
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Found inconsistency when solving.*"):
|
||||
"Found inconsistency 5 != 2 when solving.*"):
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3), (5,)],
|
||||
polymorphic_shapes=["a, ...", "a"],
|
||||
@ -1480,11 +1479,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# JAX with static shapes sees that the x.shape[0] == 0
|
||||
self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
|
||||
|
||||
# jax2tf catches the broken assumption b >= 1 if the converted function is executed
|
||||
# eagerly.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Dimension variable 'b' must have integer value >= 1. Found value 0 when solving .*"):
|
||||
"Dimension variable 'b' must have integer value >= 1. Found 0"):
|
||||
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=False)(x0)
|
||||
|
||||
@ -1511,8 +1508,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
# jax2tf catches the broken assumption b >= 1 if the converted function is executed
|
||||
# eagerly.
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Found inconsistency when solving b == .*"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Found inconsistency 5 != 4 when solving b == args\[0\].shape\[1\]"):
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False)(x45)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user