[shape_poly] Add static constraint checking to the computation of dim vars

Previously, a single function, `shape_poly.unify_avals_with_args`, both
solved for the dimension variables and generated the code to compute them.
Now we separate the solving part, which works purely on symbolic
expressions (`shape_poly.solve_dim_vars`), from the code generator for the
dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).

We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, that must hold for the solution of the dimension
variables to be valid.
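
For illustration, a minimal self-contained sketch of such a constraint
record (it mirrors the `ShapeConstraint` class in the diff below, but is
simplified and independent of JAX):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Callable

    class Comparator(Enum):
      EQ = 1   # dimexpr1 == dimexpr2
      GEQ = 2  # dimexpr1 >= dimexpr2

    @dataclass(frozen=True)
    class Constraint:
      comp: Comparator
      left: int
      right: int
      make_err_msg: Callable[[int, int], str]

      def check(self) -> None:
        # Static check: both sides are already known integers here.
        ok = (self.left == self.right if self.comp is Comparator.EQ
              else self.left >= self.right)
        if not ok:
          raise ValueError(self.make_err_msg(self.left, self.right))

    Constraint(Comparator.EQ, 12, 12, lambda l, r: f"{l} != {r}").check()  # passes
    # Constraint(Comparator.GEQ, 0, 1, ...).check() would raise ValueError.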

For now we implement only static checking of the shape constraints, i.e.,
when the dimension expressions are constants or TF EagerTensors; we do not
yet check the constraints at compile time. This matches the previous
behavior, but the code is now ready for adding compile-time checking of
the constraints that cannot be checked statically.
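
To make the new flow concrete, here is a toy model of the solve/compute
split (illustrative only; the real `solve_dim_vars` operates on avals and
returns `_DimExpr` solutions, a `ShapeConstraints` object, and
known-variable descriptors, as in the diff below):

    # 1) Solving is purely symbolic: for an argument aval f32[3, a, a+b] the
    #    solver expresses {a, b} in terms of "known" variables that name the
    #    actual argument dimensions.
    solution = {
        "a": lambda env: env["args[0].shape[1]"],
        "b": lambda env: env["args[0].shape[2]"] - env["args[0].shape[1]"],
    }
    # 2) Computing plugs in an environment: Python ints for static checking,
    #    or dimension-size tensors (e.g., stablehlo.GetDimensionSizeOp or
    #    tf.shape) when generating code.
    static_env = {"args[0].shape[1]": 4, "args[0].shape[2]": 12}
    assert solution["a"](static_env) == 4
    assert solution["b"](static_env) == 8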
George Necula 2023-05-13 16:57:27 +02:00
parent acfeb9bb13
commit 9ad8c3b9f1
7 changed files with 285 additions and 176 deletions


@@ -82,11 +82,14 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
 def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
                  ) -> ir.RankedTensorType:
   int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
+  i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
   def lower_dim(d):
     if type(d) is int:
       return ir_constant(np.array([d], np.int32))
     else:
-      return hlo.ReshapeOp(int1d, hlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
+      if d.type != i32_type:
+        d = hlo.ConvertOp(i32_type, d)
+      return hlo.ReshapeOp(int1d, d)
   ds = map(lower_dim, sizes)
   if not ds:
     return ir_constant(np.array([], np.int32))


@@ -2291,7 +2291,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
       step = 1
     elif stop is not None and step is None:
       step = 1
-    return _arange_dynamic(start, stop, step, dtype or int_)
+    return _arange_dynamic(start, stop, step, dtype or dtypes.canonicalize_dtype(np.int64))
   if dtype is None:
     dtype = result_type(start, *(x for x in [stop, step] if x is not None))
   dtype = _jnp_dtype(dtype)


@@ -564,9 +564,8 @@ class GraphSerializationImpl(SerializationImpl):
         map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))
     dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
     dim_values, _ = _interpret_fun_jax(
-        partial(shape_poly.unify_avals_with_args, self.args_avals_flat, dim_vars,
-                use_static_dimension_size=False,
-                args_kwargs_tree=self.in_tree),
+        partial(shape_poly.compute_dim_vars_from_arg_shapes,
+                self.args_avals_flat, args_kwargs_tree=self.in_tree),
         self.args_flat_tf, self.args_avals_flat, self.name_stack)
     _thread_local_state.shape_env = zip(dim_vars, dim_values)


@@ -483,12 +483,11 @@ def _compute_dim_args(
     the values of the dimension variables, in the sorted order of the
     dimension variables.
   """
   dim_vars = shape_poly.all_dim_vars(args_avals_flat)
   dim_values = mlir.lower_fun(
-      functools.partial(shape_poly.unify_avals_with_args, args_avals_flat, dim_vars,
-                        use_static_dimension_size=False,
-                        args_kwargs_tree=args_kwargs_tree),
+      functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
+                        args_avals_flat, args_kwargs_tree=args_kwargs_tree),
       multiple_results=True)(ctx, *array_args)
   res = []
   for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
     if dim_arg.type != dim_arg_type:
@@ -752,16 +751,42 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
 call_exported_p = core.Primitive("call_exported")
 call_exported_p.multiple_results = True
 
+@util.cache()
 def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
                                  exported: Exported) -> Tuple[core.AbstractValue, ...]:
   exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
   assert len(in_avals) == len(exported.in_avals)  # since the pytrees have the same structure
-  # Must express the exported_dim_vars in terms of the shapes in in_avals.
-  exported_dim_values = shape_poly.unify_avals_with_args(
-      exported.in_avals, exported_dim_vars, *in_avals,  # type: ignore
-      use_static_dimension_size=True,
-      args_kwargs_tree=exported.in_tree)
+  # Check that the expected shapes match the actual ones
+  for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
+    def pp_arg_dim(dim_idx: Optional[int]) -> str:
+      return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
+                                                          arg_idx, dim_idx)
+    if len(exp_aval.shape) != len(actual_aval.shape):
+      raise ValueError(
+          f"Rank mismatch for {pp_arg_dim(None)}: expected {exp_aval.shape} "
+          f"and called with {actual_aval.shape}")
+    if exp_aval.dtype != actual_aval.dtype:
+      raise ValueError(
+          f"Dtype mismatch for {pp_arg_dim(None)}: expected {exp_aval.dtype} "
+          f"and called with {actual_aval.dtype}")
+    for dim_idx, aval_d in enumerate(exp_aval.shape):
+      # If the exp_aval has a constant dimension then the actual argument must have
+      # a matching constant dimension.
+      if core.is_constant_dim(aval_d):
+        if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
+            aval_d != actual_aval.shape[dim_idx]):
+          raise ValueError(
+              f"Shape mismatch for {pp_arg_dim(dim_idx)} (expected constant): "
+              f"expected {exp_aval.shape} and called with {actual_aval.shape}")
+
+  # Must express the exported_dim_vars in terms of the shapes in in_avals.
+  solution, shape_constraints, known_dim_vars = shape_poly.solve_dim_vars(
+      exported.in_avals, args_kwargs_tree=exported.in_tree)
+  known_env = {vname: in_avals[arg_idx].shape[dim_idx]
+               for (vname, arg_idx, dim_idx) in known_dim_vars}
+  shape_constraints.check(known_env)
+  exported_dim_values = [solution[var].evaluate(known_env)
+                         for var in exported_dim_vars]
   return tuple(
       core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
                                            *exported_dim_values),


@@ -32,6 +32,7 @@ jax2tf.convert docstring, and the
 """
 import collections
 import dataclasses
+from enum import Enum
 import functools
 import itertools
 import io
@@ -53,6 +54,7 @@ from jax._src import dtypes
 from jax._src.interpreters import mlir
 from jax._src.numpy import lax_numpy
 from jax._src import tree_util
+from jax._src import util
 from jax._src.typing import DimSize, Shape
@@ -640,7 +642,8 @@ class _DimExpr():
   def evaluate(self, env: DimVarEnv):
     # Evaluates as a value of dtype=core.dim_value_dtype()
-    terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) for mon, coeff in self.monomials()]
+    terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff))
+             for mon, coeff in self.monomials()]
     return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
 
   @staticmethod
@staticmethod
@@ -1098,21 +1101,80 @@ def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
       dim_vars = dim_vars.union(d.get_vars())
   return sorted(tuple(dim_vars))
 
-@dataclasses.dataclass
-class DimEquation:
-  # Represents arg.shape[dim_idx] == dim_expr
-  arg: jax.Array
-  dim_idx: int
-  dim_expr: _DimExpr
-  debug_arg_str: Callable[[], str]  # A pretty-printer for a descriptor for `arg`
+@dataclasses.dataclass(frozen=True)
+class ShapeConstraint:
+  class Comparator(Enum):
+    EQ = 1
+    GEQ = 2
+
+  comp: Comparator
+  left: DimSize
+  right: DimSize
+  # make_err_msg is invoked with (left_int, right_int) if the constraint fails.
+  make_err_msg: Callable[[int, int], str]
+
+  def check(self, shapeenv: DimVarEnv) -> None:
+    """Evaluates a constraint statically and raises an error if fails."""
+    def eval_operand(o: DimSize) -> Union[int, jax.Array]:
+      if core.is_constant_dim(o): return op.index(o)
+      return o.evaluate(shapeenv)  # type: ignore
+    try:
+      left1, right1 = eval_operand(self.left), eval_operand(self.right)
+    except KeyError:
+      return None
+
+    left_int, right_int = _is_known_constant(left1), _is_known_constant(right1)
+    if left_int is not None and right_int is not None:
+      if self.comp == ShapeConstraint.Comparator.EQ:
+        if not (left_int == right_int):
+          raise ValueError(self.make_err_msg(left_int, right_int))
+      elif self.comp == ShapeConstraint.Comparator.GEQ:
+        if not (left_int >= right_int):
+          raise ValueError(self.make_err_msg(left_int, right_int))
+      else: assert False
+    else:
+      return None  # TODO: evaluate constraint dynamically
 
   def __str__(self):
-    return (f"{self.dim_expr} == {self.debug_arg_str()}"
-            f".shape[{self.dim_idx}] (statically {self.arg.shape[self.dim_idx]})")
+    return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}"
+            f" ({self.make_err_msg(self.left, self.right)})")
   __repr__ = __str__
 
+
+class ShapeConstraints:
+  def __init__(self):
+    self.constraints: Set[ShapeConstraint] = set()  # map DimConstraint to an integer >= 0
+
+  def add_constraint(self,
+                     comp: ShapeConstraint.Comparator,
+                     left: DimSize, right: DimSize,
+                     make_err_msg: Callable[[int, int], str]):
+    # Try to evaluate it statically
+    c = ShapeConstraint(comp, left, right, make_err_msg)
+    self.constraints.add(c)
+
+  def check(self, shapeenv: DimVarEnv) -> None:
+    for constraint in self.constraints:
+      constraint.check(shapeenv)
+
+
+@dataclasses.dataclass
+class _DimEquation:
+  # Represents dim_expr == dim_value, where `dim_expr` contain unknown dimension
+  # variables, in terms of `dim_value`.
+  dim_expr: _DimExpr
+  dim_value: _DimExpr
+
+  def __str__(self):
+    return f"{self.dim_expr} == {self.dim_value}"
+  __repr__ = __str__
+
 
 def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
-  # String description of args or kwargs, assuming the path is in a tree for
-  # the tuple (args, kwargs)
+  # String description of `args` or `kwargs`, assuming the path for a tree for
+  # the tuple `(args, kwargs)`.
   if path[0] == tree_util.SequenceKey(0):
     return f"args{tree_util.keystr(path[1:])}"
   elif path[0] == tree_util.SequenceKey(1):
@@ -1120,92 +1182,104 @@ def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
   else:
     assert False
 
-def unify_avals_with_args(
-    args_avals: Sequence[core.AbstractValue],
-    dim_vars: Sequence[str],
-    *args: jax.Array,
-    use_static_dimension_size: bool,
+def pretty_print_dimension_descriptor(
     args_kwargs_tree: tree_util.PyTreeDef,
-) -> Sequence[jax.Array]:
-  """Computes values of dimension variables to unify avals with actual arguments.
+    flat_arg_idx: int, dim_idx: Optional[int]) -> str:
+  args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path(
+      args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves))
+  arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0])
+  if dim_idx is not None:
+    arg_str += f".shape[{dim_idx}]"
+  return arg_str
 
-  Computes values for dimension variables for which the shapes in `args_avals`
-  (abstract values for a function's parameters) match the shapes of `args` (the
-  actual arguments). This is done by forming equations
-  between the symbolic expressions from `args_avals` and the actual dimension
-  sizes of the actual arguments, and then solving for the dimension variables.
+@util.cache()
+def solve_dim_vars(
+    args_avals: Sequence[core.AbstractValue],
+    args_kwargs_tree: tree_util.PyTreeDef,
+) -> Tuple[DimVarEnv, ShapeConstraints, Sequence[Tuple[str, int, int]]]:
+  """Solves dimension variables in a called function's avals in terms of actual argument shapes.
 
-  Not all equations are solvable. For now, the linear uni-variate equations
-  are solved first, then the solved variables are used to simplify the
-  remaining equations to linear uni-variate equations, and the process continues
+  For example, given:
+
+     args_avals = [ShapedArray((3, a, a + b), f32)]
+
+  we introduce fresh "known" dimension variables to represent the actual dimension
+  size of actual arguments for each non-constant dimension. Each known variable
+  has a name, an arg_idx, and a dim_idx, e.g.:
+
+     known_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
+
+  and then we express the solution for the unknown dimension variables {a, b}
+  as symbolic expressions in terms of the known variables:
+
+     dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])
+
+  Not all equations are solvable. For now, we solve first the linear uni-variate
+  equations, then the solved variables are used to simplify the remaining
+  equations to linear uni-variate equations, and the process continues
   until all dimension variables are solved.
 
   Args:
     args_avals: the abstract values of the `args`, with shapes that may
-      include dimension variables. A flat sequence.
-    dim_vars: the dimension variables that occur in `args_avals`. The only
-      reason we need these is to ensure that the result of this function is a
-      flat list of jax.Array in the same order.
-    args: the actual function arguments, as jax.Array or any value with `.shape`
-      and `.dtype` attributes if `use_static_dimension_size`.
-    use_static_dimension_size: if `True` then it forms the equations using the
-      static shapes of `args`. This is useful, e.g., when we want to compute
-      the dimension variables statically. If `False` then it forms the
-      equations using the dynamic shapes of `args`, e.g., using
-      `stablehlo.GetDimensionSizeOp` for native serialization or `tf.shape`
-      for TF graph serialization. This is useful when we want to generate
-      code to compute the dimension variables at compilation-time.
+      include unknown dimension variables.
     args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` from
-      which the flat sequence `args_avals` is extracted. Used for referencing args and
-      kwargs in error messages.
+      which the flat sequence `args_avals` is extracted. Used for describing
+      args and kwargs in known variable names and in error messages.
 
-  Returns: the values of `dim_vars` in the same order.
+  Returns: a 3-tuple with: (a) the solution for the unknown dimension variables
+    (b) a list of constraints that must be satisfied for the solution to be a
+    valid one, and (c) and the list of known variables that may appear in
+    the solution and the constraints.
 
-  Raises ValueError if it cannot solve for the `dim_vars`.
+  Raises ValueError if it cannot solve some dimension variable.
   """
-  dim_equations: List[DimEquation] = []
-  def debug_arg_str(flat_arg_idx: int) -> str:
-    # Debug descriptor of an argument.
-    args_avals_tree = args_kwargs_tree.unflatten(args_avals)
-    args_avals_with_paths, _ = tree_util.tree_flatten_with_path(args_avals_tree)
-    return args_kwargs_path_to_str(args_avals_with_paths[flat_arg_idx][0])
-
-  for arg_idx, (aval, arg) in enumerate(zip(args_avals, args)):
-    if len(aval.shape) != len(arg.shape):
-      raise ValueError(
-          f"Rank mismatch for {debug_arg_str(arg_idx)}: expected {aval.shape} "
-          f"and called with {arg.shape}")
-      continue
-    if aval.dtype != arg.dtype:
-      raise ValueError(
-          f"Dtype mismatch for {debug_arg_str(arg_idx)}: expected {aval.dtype} "
-          f"and called with {arg.dtype}")
+  dim_equations: List[_DimEquation] = []
+  known_dimension_vars: List[Tuple[str, int, int]] = []
+  for arg_idx, aval in enumerate(args_avals):
     for dim_idx, aval_d in enumerate(aval.shape):
-      # If the aval has a constant dimension then the actual argument must have
-      # a matching constant dimension.
-      if not is_poly_dim(aval_d):
-        if _is_known_constant(arg.shape[dim_idx]) is None or aval_d != arg.shape[dim_idx]:
-          raise ValueError(
-              f"Shape mismatch for {debug_arg_str(arg_idx)} in dimension {dim_idx}: "
-              f"expected {aval.shape} and called with {arg.shape}")
-      else:
+      if is_poly_dim(aval_d):
+        known_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
+                                                          arg_idx, dim_idx)
+        known_dimension_vars.append((known_dim_var, arg_idx, dim_idx))
         dim_equations.append(
-            DimEquation(arg=arg,
-                        dim_idx=dim_idx, dim_expr=_ensure_poly(aval_d, "unify_avals_with_args"),
-                        debug_arg_str=functools.partial(debug_arg_str, arg_idx)))
+            _DimEquation(dim_expr=_ensure_poly(aval_d, "solve_dim_vars"),
+                         dim_value=_DimExpr.from_var(known_dim_var)))
 
-  dim_env = _solve_dim_equations(dim_equations,
-                                 use_static_dimension_size=use_static_dimension_size)
-  dim_values = tuple(dim_env[dv] for dv in dim_vars)
-  return dim_values
+  solution, shape_constraints = _solve_dim_equations(dim_equations)
+  return solution, shape_constraints, known_dimension_vars
+
+
+def compute_dim_vars_from_arg_shapes(
+    args_avals: Sequence[core.AbstractValue],
+    *actual_args: jax.Array,
+    args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]:
+  """Computes values of dimension variables to unify args_avals with actual arguments.
+
+  Like `solve_dim_vars` except that here we express the solution as
+  JAX arrays that reference the `actual_args`. This function can be used to
+  generate the code for computing the dimension variables.
+
+  Returns: the values of the dimension variables, in the order determined by
+    `all_dim_vars(args_avals)`.
+  """
+  dim_vars = all_dim_vars(args_avals)
+  solution, shape_constraints, known_dim_vars = solve_dim_vars(
+      tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
+
+  # Replace the synthetic vars with the dynamic shape of the actual arg
+  known_env = {vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
+               for (vname, arg_idx, dim_idx) in known_dim_vars}
+  dim_values = [solution[var].evaluate(known_env) for var in dim_vars]
+  shape_constraints.check(known_env)
+  return tuple(dim_values)
 
-def _solve_dim_equations(eqns: List[DimEquation],
-                         use_static_dimension_size: bool) -> DimVarEnv:
-  # Returns a shape environment if it can solve all dimension variables.
-  # Raises an exception if it cannot.
+def _solve_dim_equations(
+    eqns: List[_DimEquation]
+) -> Tuple[DimVarEnv, ShapeConstraints]:
+  # Returns a shape environment and the shape constraints if it can solve all
+  # dimension variables. Raises an exception if it cannot.
   shapeenv: DimVarEnv = {}
+  shape_constraints = ShapeConstraints()
   def _shapeenv_to_str() -> str:
     if shapeenv:
       return (" Partial solution: " +
@@ -1213,27 +1287,17 @@ def _solve_dim_equations(eqns: List[DimEquation],
     else:
       return ""
 
-  def process_one_eqn(eqn: DimEquation) -> bool:
+  def process_one_eqn(eqn: _DimEquation) -> bool:
+    # We start with a DimEquation of the form `dim_expr = dim_value`
     # Try to rewrite the equation as `var * factor_var = dim_value_2` (a linear
     # uni-variate equation). Returns `False` if this rewrite fails.
     # Otherwise, compute the `var` value as `dim_value_2 // factor`, add it to
     # `shapeenv` and return `True`.
     #
-    # TODO: does not yet fully handle the cases when `dim_value` is not
-    # divisible by `factor`, or when the value is not greater or equal to 1.
-
     # Invariant:
     #     var * factor_var + remaining_monomials_from_dim_expr = dim_value
     var, factor_var = None, None
-    if use_static_dimension_size:
-      dim_value = eqn.arg.shape[eqn.dim_idx]
-      if _is_known_constant(dim_value):
-        dim_value = core.dim_constant(dim_value)
-    else:
-      # We use the dimension_size_p primitive when we want to lower to code
-      # that fetches the dimension size at compile-time.
-      dim_value = dimension_size_p.bind(eqn.arg, dimension=eqn.dim_idx)
+    dim_value = eqn.dim_value
 
     for mon, factor in eqn.dim_expr.monomials():
       # Perhaps we can already evaluate this monomial (all vars solved)
@@ -1253,26 +1317,22 @@ def _solve_dim_equations(eqns: List[DimEquation],
     if var is not None:
       if factor_var == 1:
-        var_value, var_remainder = dim_value, core.dim_constant(0)
+        var_value = dim_value
       else:
         var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var))  # type: ignore
+        shape_constraints.add_constraint(
+            ShapeConstraint.Comparator.EQ, var_remainder, 0,
+            make_err_msg=lambda rem_int, _: (
+                f"Dimension variable '{var}' must have integer value >= 1. "
+                f"Non-zero remainder {rem_int} for factor {factor_var} when solving "
+                f"{eqn}.{_shapeenv_to_str()}"))
 
-      # Check that the division is even. Works only in TF eager mode.
-      # TODO: check dynamically if not possible statically
-      # Or maybe we should abort right away if the remainder is not 0?
-      var_remainder_int = _is_known_constant(var_remainder)
-      if var_remainder_int is not None and var_remainder_int != 0:
-        msg = (f"Dimension variable '{var}' must have integer value >= 1. "  # type: ignore
-               f"Found value {int(_is_known_constant(dim_value)) / factor_var} when solving "  # type: ignore
-               f"{eqn}.{_shapeenv_to_str()}")
-        raise ValueError(msg)
-      var_value_int = _is_known_constant(var_value)
-      if var_value_int is not None and var_value_int <= 0:
-        # TODO: check dynamically if not possible statically
-        msg = (f"Dimension variable '{var}' must have integer value >= 1. "
-               f"Found value {int(var_value_int)} when solving "
-               f"{eqn}.{_shapeenv_to_str()}")
-        raise ValueError(msg)
+      shape_constraints.add_constraint(
+          ShapeConstraint.Comparator.GEQ, var_value, 1,
+          make_err_msg=lambda var_int, _: (
+              f"Dimension variable '{var}' must have integer value >= 1. "
+              f"Found {var_int} when "
+              f"solving {eqn}.{_shapeenv_to_str()}"))
 
       if not isinstance(var_value, _DimExpr):
         assert var_value.dtype == core.dim_value_dtype()
@@ -1280,20 +1340,18 @@ def _solve_dim_equations(eqns: List[DimEquation],
       return True
     else:
       # All variables are resolved for this equation
-      dim_value_int = _is_known_constant(dim_value)
-      if dim_value_int is not None and dim_value_int != 0:
-        # TODO: check dynamically if not possible statically
-        err_msg = (
-            "Found inconsistency when solving "
-            f"{eqn}.{_shapeenv_to_str()}")
-        raise ValueError(err_msg)
+      shape_constraints.add_constraint(
+          ShapeConstraint.Comparator.EQ, eqn.dim_value,
+          eqn.dim_expr.evaluate(shapeenv),
+          make_err_msg=lambda val1, val2: (
+              f"Found inconsistency {val1} != {val2} when solving {eqn}.{_shapeenv_to_str()}"))
       return True
 
   while True:
     nr_eqns = len(eqns)
     eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
     if not eqns:
-      return shapeenv  # SUCCESS
+      return shapeenv, shape_constraints  # SUCCESS
     elif len(eqns) >= nr_eqns:
       break


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import contextlib
-import re
 import unittest
 
 from absl.testing import absltest, parameterized
@@ -69,13 +68,25 @@ class JaxExportTest(jtu.JaxTestCase):
 
   def test_poly_export_only(self):
     a = np.arange(12, dtype=np.float32).reshape((3, 4))
-    def f(a):
-      return jnp.concatenate([a, a], axis=0)
+    def f(a, b):  # a: f32[2w,h]  b: f32[w,h]
+      return jnp.concatenate([a, b], axis=0)
 
     exp = jax_export.export(f)(
+        jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
         jax_export.poly_spec(a.shape, a.dtype, "(w, h)"))
+    self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
+    self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
+    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
+
+  def test_poly_pytree_export_only(self):
+    a = np.arange(12, dtype=np.float32).reshape((3, 4))
+    def f(a0, a1, *, ak):
+      return jnp.concatenate([a0, a1, ak], axis=0)
+
+    a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)")
+    exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
     self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
-    self.assertEqual("(2*w, h)", str(exp.out_avals[0].shape))
+    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
 
   def test_basic(self):
     f = jnp.sin
@@ -142,11 +153,11 @@ class JaxExportTest(jtu.JaxTestCase):
     exp_f = jax_export.export(f)(f32_4, b=f32_4)
 
     with self.assertRaisesRegex(ValueError,
-                                r"Shape mismatch for args\[0\] in dimension 0"):
+                                r"Shape mismatch for args\[0\].shape\[0\]"):
       jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
 
     with self.assertRaisesRegex(ValueError,
-                                r"Shape mismatch for kwargs\['b'\] in dimension 0"):
+                                r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
       jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
 
     with self.assertRaisesRegex(ValueError,
@@ -214,50 +225,65 @@
                         jax_export.call_exported(exp_f2)(a))
 
   # An inner function is exported with polymorphic shapes inner_poly_spec, and
-  # is called from an outer function, that is exported with outer_poly_spec.
+  # is called from an outer function, which is exported with outer_poly_spec.
   @parameterized.named_parameters(
-      dict(testcase_name=f"inner={inner_poly_spec}_outer={outer_poly_spec}",
-           inner_poly_spec=inner_poly_spec, outer_poly_spec=outer_poly_spec,
-           expect_error=expect_error)
-      for inner_poly_spec, outer_poly_spec, expect_error in (
-          ("3,a,a+b", "3,4,12", None),
-          ("3,a,a+b", "3,4,c", None),
-          ("3,a,a+b", "3,c,c", r"Dimension variable.*b.*must have.* >= 1. Found value 0"),
-          ("3,a,a+b", "c,4,12", r"Shape mismatch for args\[0\] in dimension 0"),
-          ("3,a,a+b", "3,c+4,12", None),  # TODO: This should be an error, c = 0
-          ("3,4,3*a", "3,4,12", None),
-          ("3,4,5*a", "3,4,12", r"Dimension variable 'a' must have integer value >= 1. Found value 2.4"),
-          # ("3,a,a", "3,a,a", None),  # TODO: wrong error. It should be shape mismatch
-          # ("3,4,5*a", "3,4,c", None),  # TODO: wrong error. It should be "not divisible by 5"
+      dict(testcase_name=f"inner={d['inner_poly_spec']}_outer={d['outer_poly_spec']}",  # type: ignore
+           **d)  # type: ignore
+      for d in (
+          dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
+          dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
+          dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
+               expect_error=(
+                   r"Dimension variable 'b' must have integer value >= 1. "
+                   r"Found 0 when solving a \+ b == args\[0\].shape\[2\]")),
+          dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
+               expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
+          dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"),  # TODO: This should be an error, c = 0
+          dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
+          dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
+               expect_error=(
+                   r"Dimension variable 'a' must have integer value >= 1. "
+                   r"Non-zero remainder 2 for factor 5 when solving 5\*a == args\[0\].shape\[2\]")),
+          # dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"),  # TODO: there should be an error 5*a != c == 12
+          # dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"),  # TODO: this should be a dynamic error
+          dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
+               expect_error=r"Rank mismatch for args\[0\]"),
+          dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
+               expect_error=r"Dtype mismatch for args\[0\]"),
      ))
-  def test_poly(self, inner_poly_spec="3,a,a+b",
-                outer_poly_spec="3,4,12", expect_error=None):
+  def test_poly(self, inner_poly_spec="3,a,a+b", inner_x_shape=(3, 4, 6),
+                inner_x_dtype=np.float32,
+                outer_poly_spec="3,c+4,12", outer_x_shape=(3, 4, 12),
+                expect_error=None):
     # Polymorphic export called with static or polymorphic shapes
-    def inner(x):  # x: export_poly_spec
-      return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
+    def inner(x):  # x: inner_poly_spec
+      return jnp.reshape(x, (-1, x.shape[1]))
 
-    x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6))  # x1 : f32[3,4,6]
-    exp1 = jax_export.export(inner)(jax_export.poly_spec(x1.shape, x1.dtype, inner_poly_spec))
+    inner_x = np.arange(np.prod(inner_x_shape),
+                        dtype=inner_x_dtype).reshape(inner_x_shape)  # inner_x : f32[3,4,6]
+    inner_exp = jax_export.export(inner)(
+        jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
 
-    x2 = np.concatenate([x1, x1], axis=2)  # x2: f32[3,4,12]
-    def outer(x):  # x: call_poly_spec
+    outer_x = np.arange(np.prod(outer_x_shape),
+                        dtype=np.float32).reshape(outer_x_shape)  # outer_x : f32[3,4,12]
+    def outer(x):  # x: outer_poly_spec
       # Use an addition to test that the shapes are refined properly for the
       # result of the call_exported.
-      return jax_export.call_exported(exp1)(x) + inner(x)
+      return jax_export.call_exported(inner_exp)(x) + inner(x)
 
     with contextlib.ExitStack() as stack:
       if expect_error is not None:
         stack.push(self.assertRaisesRegex(ValueError, expect_error))
 
       # Call it after exporting again, with polymorphic shapes
-      exp2 = jax_export.export(outer)(
-          jax_export.poly_spec(x2.shape, x2.dtype, outer_poly_spec))
+      outer_exp = jax_export.export(outer)(
+          jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
       # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
       # until we create the python bindings to invoke shape refinement.
       if jax2tf is not None:
-        res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()
+        res2 = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
         # res2 = jax_export.call_exported(exp2)(x2)
-        self.assertAllClose(2. * inner(x2), res2)
+        self.assertAllClose(2. * inner(outer_x), res2)
 
 if __name__ == "__main__":


@@ -715,7 +715,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
           ("0_b_1_f32", lambda b: (0, b, 1, np.float32))
       ]
   ])
-  def test_arange(self, make_args=lambda b: (0, 0, b)):
+  def test_arange(self, make_args=lambda b: (0, -b, 2, None)):
    def f_jax(x):  # x: i32[b]
      return x[0] + jnp.arange(*(make_args(x.shape[0])))
    x = np.ones((3,), dtype=np.int32)
@@ -850,9 +850,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
     dim_vars = shape_poly.all_dim_vars(avals)
     dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
-        partial(shape_poly.unify_avals_with_args, avals, dim_vars,
-                use_static_dimension_size=False,
-                args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
+        partial(shape_poly.compute_dim_vars_from_arg_shapes,
+                avals, args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
         args_tf, avals, "")
     if expected_avals is not None:
       self.assertEqual(expected_avals, avals)
@@ -988,7 +987,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     with self.assertRaisesRegex(
         ValueError,
-        "Found inconsistency when solving.*"):
+        "Found inconsistency 3 != 2 when solving.*"):
      check_avals(
          arg_shapes=[(2, 3)],
          polymorphic_shapes=["(a, a)"],
@@ -997,7 +996,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     # Same error across multiple arguments
     with self.assertRaisesRegex(
         ValueError,
-        "Found inconsistency when solving.*"):
+        "Found inconsistency 5 != 2 when solving.*"):
      check_avals(
          arg_shapes=[(2, 3), (5,)],
          polymorphic_shapes=["a, ...", "a"],
@@ -1480,11 +1479,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     # JAX with static shapes sees that the x.shape[0] == 0
     self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
 
-    # jax2tf catches the broken assumption b >= 1 if the converted function is executed
-    # eagerly.
     with self.assertRaisesRegex(
         ValueError,
-        "Dimension variable 'b' must have integer value >= 1. Found value 0 when solving .*"):
+        "Dimension variable 'b' must have integer value >= 1. Found 0"):
       jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
                      native_serialization=False)(x0)
@@ -1511,8 +1508,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     # jax2tf catches the broken assumption b >= 1 if the converted function is executed
     # eagerly.
-    with self.assertRaisesRegex(ValueError,
-                                "Found inconsistency when solving b == .*"):
+    with self.assertRaisesRegex(
+        ValueError,
+        r"Found inconsistency 5 != 4 when solving b == args\[0\].shape\[1\]"):
       jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
                      native_serialization=False)(x45)