Merge pull request #16211 from gnecula:poly_checks

PiperOrigin-RevId: 540030091
jax authors 2023-06-13 11:54:50 -07:00
commit e9f2e40a3e
6 changed files with 744 additions and 301 deletions

View File

@ -22,7 +22,7 @@ import re
import os
import tempfile
import textwrap
from typing import Callable, List, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Generator, Optional, Sequence, Tuple, Union
import unittest
import warnings
import zlib
@ -1207,3 +1207,42 @@ def _parse_version(v: str) -> Tuple[int, ...]:
def numpy_version():
return _parse_version(np.__version__)
def parameterized_filterable(*,
kwargs: Sequence[Dict[str, Any]],
testcase_name: Optional[Callable[[Dict[str, Any]], str]] = None,
one_containing: Optional[str] = None,
):
"""
Decorator for named parameterized tests, with filtering.
Works like parameterized.named_parameters, except that it supports the
`one_containing` option. This is useful for selecting only one of the tests
while leaving the test name unchanged (which helps with specifying the desired
test when debugging).
Args:
kwargs: Each entry is a set of kwargs to be passed to the test function.
testcase_name: Optionally, a function to construct the testcase_name from
one kwargs dict. If not given, then each kwargs dict must contain a
`testcase_name` key.
one_containing: If given, then leave the test name unchanged, and use
only the first `kwargs` dict whose `testcase_name` contains `one_containing`.
"""
# Ensure that all kwargs contain a testcase_name
kwargs_with_testcase_name: Sequence[Dict[str, Any]]
if testcase_name is not None:
kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
for kw in kwargs]
else:
for kw in kwargs:
assert "testcase_name" in kw
kwargs_with_testcase_name = kwargs
if one_containing is not None:
filtered = tuple(kw for kw in kwargs_with_testcase_name
if one_containing in kw["testcase_name"])
assert filtered, f"No testcase_name contains '{one_containing}'"
kw = filtered[0]
kw["testcase_name"] = ""
return parameterized.named_parameters([kw])
else:
return parameterized.named_parameters(*kwargs_with_testcase_name)
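# Illustrative usage sketch (hypothetical test class and parameters), showing
# the decorator as a drop-in replacement for parameterized.named_parameters:
class _ExampleFilterableTest(JaxTestCase):

  @parameterized_filterable(
    testcase_name=lambda kw: f"dtype={np.dtype(kw['dtype']).name}",
    # one_containing="float32",  # uncomment to run only the matching case
    kwargs=[dict(dtype=np.float32), dict(dtype=np.int32)])
  def test_zeros_dtype(self, *, dtype):
    self.assertEqual(np.zeros((), dtype).dtype, np.dtype(dtype))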

View File

@ -575,6 +575,18 @@ class GraphSerializationImpl(SerializationImpl):
partial(shape_poly.compute_dim_vars_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
# We invoke shape checking to give it a chance to raise shape errors that
# are evident statically. This should work in TF eager mode because all
# the shapes are known.
# TODO: handle non-static shape checking for graph serialization
acc_shape_check_messages: List[str] = []
_, _ = _interpret_fun_jax(
partial(shape_poly.compute_shape_check_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree,
acc_shape_check_messages=acc_shape_check_messages),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
_thread_local_state.shape_env = zip(dim_vars, dim_values)
fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)

View File

@ -23,6 +23,7 @@ import re
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
from absl import logging
import numpy as np
import jax
from jax import sharding
@ -101,10 +102,10 @@ class Exported:
expressions in the shapes, with dimension variables among those in
`in_avals`.
in_shardings: the flattened input shardings. Only for the inputs that are
specified in `module_kept_var_idx`.
specified in `module_kept_var_idx`. If `None` then it is equivalent
to unspecified shardings.
out_shardings: the flattened output shardings, with the same length as `out_avals`.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
mlir_module_serialized: the serialized lowered VHLO module.
xla_call_module_version: a version number for the serialized module.
See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code
@ -113,7 +114,7 @@ class Exported:
because they are not used. Same length as `in_shardings`.
module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
polymorphic dimension variables. This may be from `in_avals` but also
from inner calls of Exported modules.
from inner calls of shape-polymorphic Exported modules.
disabled_checks: a list of descriptors of safety checks that have been
disabled at export time.
_get_vjp: an optional function that takes the current exported function and
@ -129,8 +130,8 @@ class Exported:
out_tree: tree_util.PyTreeDef
out_avals: Tuple[core.AbstractValue, ...]
in_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
out_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
in_shardings: Optional[Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
out_shardings: Optional[Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
lowering_platform: str
disabled_checks: Sequence[DisabledSafetyCheck]
@ -145,6 +146,66 @@ class Exported:
def mlir_module(self) -> ir.Module:
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def shape_check_module(self) -> Optional[Tuple[bytes, Sequence[str]]]:
"""Generates a serialized shape checking module and the error messages.
Consider the exporting of a function with one array argument of type
"f32[w, 2 * h]", where "w" and "h" are two dimension variables. JAX tracing
assumes that `arg.shape[1]` is even, and that both `w` and `h` have
values >= 1. We must ensure that these assumptions hold when the function
is invoked.
If we invoke this module with `call_exported`, we perform these checks
statically (in `call_exported_abstract_eval`).
But if we wrap the exported module with XlaCallModule for execution outside
JAX, we must defer these checks to compile time, when the static shapes
are known.
We generate a separate MLIR shape checking module with a main
function that returns a triple of int32s: a shape checking code and
two shape checking operands. The shape checking code is -1 if the checks
pass, and otherwise is an index into the returned shape_check_messages.
The shape checking operands are substituted for "%1" and "%2" in
the selected shape checking message to obtain the error message. Here is a
shape checking module for our example:
func public main(arg: f32[?, ?]) {
# code will be the index of the last failed shape check.
code = op1 = op2 = -1 # Assume no errors initially
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
code = hlo.select(arg_w >= 1, code, 0) # error message 0
op1 = hlo.select(arg_w >= 1, op1, arg_w)
op2 = hlo.select(arg_w >= 1, op2, 1)
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
code = hlo.select(dim1 % 2 == 0, code, 1) # error message 1
op1 = hlo.select(dim1 % 2 == 0, op1, dim1 % 2)
op2 = hlo.select(dim1 % 2 == 0, op2, 0)
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
code = hlo.select(arg_h >= 1, code, 2) # error message 2
op1 = hlo.select(arg_h >= 1, op1, arg_h)
op2 = hlo.select(arg_h >= 1, op2, 1)
return (code, op1, op2)
We also return the following shape checking messages to be used for
constructing error messages:
shape_check_messages = [
"Dimension variable 'w' must have integer value >= 1. Found %1",
"Dimension variable 'h' must have integer value >= 1. Found non-zero remainder %1",
"Dimension variable 'h' must have integer value >= 1. Found %1"]
"""
shape_check_res = _make_shape_check_module(
self.in_avals, args_kwargs_tree=self.in_tree)
if shape_check_res is None:
return None
module, shape_check_messages = shape_check_res
module_serialized, xla_call_module_version = _serialize_module(module)
assert xla_call_module_version == self.xla_call_module_version
return module_serialized, shape_check_messages
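# A hedged usage sketch for shape_check_module (the function `f` and its
# polymorphic spec are hypothetical):
#
#   exp = export(f)(poly_spec((5, 8), np.float32, "w, 2 * h"))
#   shape_check = exp.shape_check_module()
#   if shape_check is not None:
#     module_serialized, shape_check_messages = shape_check
#     # `module_serialized` is a serialized StableHLO artifact whose main
#     # function returns (code, op1, op2): code == -1 means all checks passed,
#     # otherwise code indexes shape_check_messages, with "%1" and "%2" in the
#     # message replaced by op1 and op2.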
def __str__(self):
# This is called to make an MLIR source location when we call an Exported, and we
# do not want the entire serialized module to end up in locations.
@ -186,6 +247,10 @@ def poly_spec(
convenience, zero or more trailing `_` can be abbreviated with `...`, and
the surrounding parentheses may be missing.
Note that this function does not ensure that the provided `arg_shape`
is compatible with `polymorphic_shape`. The `arg_shape` is used only
to fill in placeholders from `polymorphic_shape`.
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
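# A minimal illustrative sketch of the note above: `arg_shape` is used only to
# fill in the placeholders of `polymorphic_shape`; e.g. the `...` below expands
# to the trailing concrete dimensions (4, 12), and no check is made that the
# first dimension of `arg_shape` is consistent with "b".
example_spec = poly_spec((3, 4, 12), np.float32, "b, ...")  # spec for f32[b, 4, 12]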
@ -215,6 +280,10 @@ def poly_specs(
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
Note that this function does not ensure that the provided `args` shapes
are compatible with `polymorphic_shapes`. The shapes of `args` are used only
to fill in placeholders from `polymorphic_shapes`.
See docstring of `poly_spec` and
[the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
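# Similarly, an illustrative sketch for poly_specs: the example arguments below
# are hypothetical, and only their shapes and dtypes are used to fill in the
# placeholders.
example_args = (np.zeros((3, 4), np.float32), np.zeros((3,), np.float32))
example_specs = poly_specs(example_args, polymorphic_shapes=("b, _", "b"))
# example_specs has the same pytree structure as example_args, with specs for
# f32[b, 4] and f32[b].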
@ -290,9 +359,7 @@ def export(fun_jax: Callable,
*args_specs, **kwargs_specs,
_experimental_lowering_platform=lowering_platform_str)
lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
@ -309,31 +376,7 @@ def export(fun_jax: Callable,
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
)
xla_call_module_version = 6
mlir_str = mlir.module_to_bytecode(mlir_module)
if stablehlo.get_api_version() < 4:
target_version = stablehlo.get_earliest_forward_compatible_version()
else:
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = stablehlo.get_minimum_version()
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
mlir_module_serialized, xla_call_module_version = _serialize_module(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowering.compile_args:
@ -377,6 +420,34 @@ def export(fun_jax: Callable,
return do_export
def _serialize_module(module: ir.Module) -> Tuple[bytes, int]:
xla_call_module_version = 6
mlir_str = mlir.module_to_bytecode(module)
if stablehlo.get_api_version() < 4:
target_version = stablehlo.get_earliest_forward_compatible_version()
else:
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = stablehlo.get_minimum_version()
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
return module_serialized, xla_call_module_version
def _wrap_main_func(
module: ir.Module,
args_avals_flat: Sequence[core.ShapedArray],
@ -392,7 +463,7 @@ def _wrap_main_func(
the `args_avals`, in sorted order.
Consider the lowering of a function with one array argument of type "f32[w,
h]",
2 * h]",
where "w" and "h" are two dimension variables. The `module` will also
contain two dimension arguments, corresponding to "h" and "w" respectively:
@ -403,8 +474,9 @@ def _wrap_main_func(
we rename "main" to "_wrapped_jax_export_main" and add a new "main":
func public main(arg: f32[?, ?]) {
arg_h = hlo.get_dimension_size(arg, 1)
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
res = call _wrapped_jax_export_main(arg_h, arg_w, arg)
return res
}
@ -424,11 +496,11 @@ def _wrap_main_func(
Returns the wrapped module.
"""
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
# Make a new module, do not mutate the "module" because it may be cached
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
new_module = ir.Module.parse(mlir.module_to_bytecode(module))
symbol_table = ir.SymbolTable(new_module.operation)
# Make a copy, do not mutate because it may be cached
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
@ -446,8 +518,9 @@ def _wrap_main_func(
nr_array_args = len(orig_input_types) - len(dim_vars) - nr_token_args
assert nr_array_args >= 0
assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:])
new_main_input_types = orig_input_types[-nr_array_args:]
# The order of args: dim args, token args, array args.
new_main_input_types = orig_input_types[- nr_array_args:]
dim_var_input_types = orig_input_types[:len(dim_vars)]
orig_output_types = orig_main.type.results
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs))
@ -457,10 +530,9 @@ def _wrap_main_func(
is_token(attrs) for attrs in result_attrs[-nr_array_results:]
)
new_main_output_types = orig_output_types[-nr_array_results:]
ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
new_main_op = func_dialect.FuncOp(
"main", ftype, ip=ir.InsertionPoint.at_block_begin(new_module.body))
"main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
try:
new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[-nr_array_args:])
@ -479,13 +551,21 @@ def _wrap_main_func(
module_context = mlir.ModuleContext(
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=new_module, context=context)
[], itertools.count(1), [], module=wrapped_module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals_flat, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
dim_args = _compute_dim_args(ctx, args_avals_flat, tuple(new_main_op.arguments),
orig_input_types[:len(dim_vars)],
args_kwargs_tree=args_kwargs_tree)
dim_values = mlir.lower_fun(
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
multiple_results=True)(ctx, *new_main_op.arguments)
dim_args = []
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_var_input_types):
if dim_arg.type != dim_arg_type:
dim_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
else:
dim_args.append(dim_arg)
# The first arguments are the dimension variable
orig_main_args.extend(dim_args)
# Then the token arguments
@ -497,40 +577,56 @@ def _wrap_main_func(
orig_main_args)
func_dialect.ReturnOp(call.results[-nr_array_results:])
symbol_table.set_symbol_name(new_main_op, "main")
return new_module
return wrapped_module
def _compute_dim_args(
ctx: mlir.LoweringRuleContext,
args_avals_flat: Sequence[core.ShapedArray],
array_args: Sequence[ir.Value],
dim_arg_types: Sequence[ir.Type], *,
args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[ir.Value]:
"""Compute the values of the dimension arguments.
def _make_shape_check_module(
args_avals_flat: Sequence[core.AbstractValue],
*,
args_kwargs_tree: tree_util.PyTreeDef
) -> Optional[Tuple[ir.Module, Sequence[str]]]:
"""Codegens the shape checking function.
Args:
args_avals_flat: the abstract values of the array arguments.
array_args: the values of the array arguments.
dim_arg_types: the desired types for the dimension arguments.
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for
error messages.
The shape checking function takes the array inputs and returns a triple.
See `Exported.shape_check_module` docstring.
Returns:
the values of the dimension variables, in the sorted order of the
dimension variables.
Returns the shape checking module and the tuple of shape checking messages.
"""
dim_values = mlir.lower_fun(
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
multiple_results=True)(ctx, *array_args)
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
array_input_types = tuple(mlir.aval_to_ir_type(a) for a in args_avals_flat)
module = ir.Module.create(loc=ir.Location.unknown(context))
res = []
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
if dim_arg.type != dim_arg_type:
res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
else:
res.append(dim_arg)
return tuple(res)
symbol_table = ir.SymbolTable(module.operation)
shape_check_code_type = mlir.aval_to_ir_type(core.ShapedArray((), np.int32))
shape_check_output_types = (shape_check_code_type,) * 3
shape_check_ftype = ir.FunctionType.get(array_input_types,
shape_check_output_types)
shape_check_func = func_dialect.FuncOp(
"main", shape_check_ftype,
ip=ir.InsertionPoint.at_block_begin(module.body))
shape_check_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(shape_check_func)
entry_block = shape_check_func.add_entry_block()
with ir.InsertionPoint(entry_block):
module_context = mlir.ModuleContext(
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals_flat,
avals_out=None, tokens_in=mlir.TokenSet(),
tokens_out=None)
acc_shape_check_messages: List[str] = []
values = mlir.lower_fun(
functools.partial(shape_poly.compute_shape_check_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree,
acc_shape_check_messages=acc_shape_check_messages),
multiple_results=True)(ctx, *shape_check_func.arguments)
func_dialect.ReturnOp(util.flatten(values))
return (module, acc_shape_check_messages) if acc_shape_check_messages else None
def _check_lowering(lowering) -> None:
@ -707,9 +803,9 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
# Expand in_shardings to all in_avals even not kept ones.
all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals)
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
primal.in_shardings):
primal.in_shardings): # type: ignore
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(primal.out_shardings)
all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore
# Cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated.
specified_shardings = [
@ -817,16 +913,28 @@ def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
aval_d != actual_aval.shape[dim_idx]):
raise ValueError(
f"Shape mismatch for {pp_arg_dim(dim_idx)} (expected constant): "
f"Shape mismatch for {pp_arg_dim(dim_idx)} "
"(expected same constant): "
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
# Must express the exported_dim_vars in terms of the shapes in in_avals.
solution, shape_constraints, known_dim_vars = shape_poly.solve_dim_vars(
solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
exported.in_avals, args_kwargs_tree=exported.in_tree)
known_env = {vname: in_avals[arg_idx].shape[dim_idx]
for (vname, arg_idx, dim_idx) in known_dim_vars}
shape_constraints.check(known_env)
exported_dim_values = [solution[var].evaluate(known_env)
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
for (vname, arg_idx, dim_idx) in synth_dim_vars}
# We discharge all the constraints statically. This results in much simpler
# composability (because we do not have to worry about the constraints of the
# Exported called recursively; we only need to worry about entry-point
# constraints). This also makes sense from a composability point of view,
# because we get the same errors if we invoke the exported module, or if we
# trace the exported function. Consider, for example, an exported module with
# signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
# `f32[c, d]` it is better to fail because `c == d` is inconclusive, than to
# succeed and add a compile-time check that `c == d`. In the latter case,
# it would be ambiguous whether we should continue tracing with a result
# of type `f32[c]` or `f32[d]`.
shape_constraints.check_statically(synthetic_env)
exported_dim_values = [solution[var].evaluate(synthetic_env)
for var in exported_dim_vars]
return tuple(
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
@ -881,8 +989,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
kept_args)
# The ctx.avals_out already contain the abstract values refined by
# _call_exported_abstract_eval.
return tuple(convert_shape(out, out_aval, refined_out_aval)
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
return tuple(
convert_shape(out, out_aval, refined_out_aval)
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
for _p in ("cpu", "tpu", "cuda", "rocm"):

View File

@ -51,6 +51,7 @@ from jax.interpreters import xla
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src import tree_util
@ -1164,20 +1165,61 @@ class ShapeConstraint:
comp: Comparator
left: DimSize
right: DimSize
# make_err_msg is invoked with (left_int, right_int) if the constraint fails.
make_err_msg: Callable[[int, int], str]
# make_err_msg lazily produces an error message that may refer to the
# two operands being compared.
make_err_msg: Callable[[Any, Any], str]
def check(self, shapeenv: DimVarEnv) -> None:
"""Evaluates a constraint statically and raises an error if fails."""
def eval_operand(o: DimSize) -> Union[int, jax.Array]:
if core.is_constant_dim(o): return op.index(o)
return o.evaluate(shapeenv) # type: ignore
def _eval_operand(self, o: DimSize, shapeenv: DimVarEnv,
) -> Union[int, jax.Array]:
if core.is_constant_dim(o):
res = op.index(o)
else:
res = o.evaluate(shapeenv) # type: ignore
# Ensure we have np.int32
res_dtype = dtypes.dtype(res)
if res_dtype == np.int32 or is_poly_dim(res):
return res
if type(res) in dtypes.python_scalar_dtypes:
return dtypes.coerce_to_array(res, np.int32) # type: ignore
else:
return res.astype(np.int32)
def _err_msg(self, left, right):
return self.make_err_msg(str(left), str(right))
def check_statically(self, shapeenv: DimVarEnv) -> None:
"""Evaluates a constraint statically.
The `shapeenv` maps variables to _DimExpr. If the static checking
of the constraint fails, raise ValueError.
"""
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
self.right]]
try:
left1, right1 = eval_operand(self.left), eval_operand(self.right)
except KeyError:
return None
if self.comp == ShapeConstraint.Comparator.EQ:
ok = (left == right)
elif self.comp == ShapeConstraint.Comparator.GEQ:
ok = (left >= right)
else: assert False
except InconclusiveDimensionOperation as e:
raise ValueError(self.make_err_msg(left, right)) from e
if not ok:
raise ValueError(self.make_err_msg(left, right))
left_int, right_int = _is_known_constant(left1), _is_known_constant(right1)
def compute(self,
shapeenv: DimVarEnv
) -> Optional[Tuple[jax.Array, jax.Array, jax.Array]]:
"""Computes if the constraint is satisfied.
If the constraint can be resolved statically returns None
or raises ValueError otherwise. If the constraint cannot be
resolved statically, returns a triple with a boolean encoding if the
constraint is satisfied and the int32 operands of binary constraints.
"""
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
self.right]]
# Try to evaluate the constraint statically.
left_int, right_int = _is_known_constant(left), _is_known_constant(right)
if left_int is not None and right_int is not None:
if self.comp == ShapeConstraint.Comparator.EQ:
if not (left_int == right_int):
@ -1186,31 +1228,69 @@ class ShapeConstraint:
if not (left_int >= right_int):
raise ValueError(self.make_err_msg(left_int, right_int))
else: assert False
else:
return None # TODO: evaluate constraint dynamically
return None
if self.comp == ShapeConstraint.Comparator.EQ:
is_ok = lax.eq(left, right)
elif self.comp == ShapeConstraint.Comparator.GEQ:
is_ok = lax.ge(left, right)
else: assert False
return is_ok, left, right # type: ignore
def __str__(self):
return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}"
f" ({self.make_err_msg(self.left, self.right)})")
f" ({self.make_err_msg('%1', '%2')})")
__repr__ = __str__
class ShapeConstraints:
def __init__(self):
self.constraints: Set[ShapeConstraint] = set() # map DimConstraint to an integer >= 0
self.constraints: List[ShapeConstraint] = []
def add_constraint(self,
comp: ShapeConstraint.Comparator,
left: DimSize, right: DimSize,
make_err_msg: Callable[[int, int], str]):
make_err_msg: Callable[[Any, Any], str]):
# Try to evaluate it statically
c = ShapeConstraint(comp, left, right, make_err_msg)
self.constraints.add(c)
self.constraints.append(c)
def check(self, shapeenv: DimVarEnv) -> None:
def check_statically(self, shapeenv: DimVarEnv) -> None:
"""Evaluates all the constraints statically.
The `shapeenv` maps variables to _DimExpr. If the static checking
of any constraint fails, raise ValueError.
"""
for constraint in self.constraints:
constraint.check(shapeenv)
constraint.check_statically(shapeenv)
def compute(self, shapeenv: DimVarEnv) -> Tuple[jax.Array, jax.Array, jax.Array, Sequence[str]]:
"""Computes the error code for the set of constraints.
The error code is -1 if all constraints are satisfied, or an index into
the returned error messages.
See Exported.shape_check_module docstring.
"""
# We want to report the errors in the same order as `check_statically`.
# So, we process them in order, in case some fail statically, but we
# accumulate the errors in reverse order in the computation, because the
# code-generation strategy will report the last error that failed.
acc: List[Tuple[jax.Array, jax.Array, jax.Array, str]] = []
for constraint in self.constraints:
check_res = constraint.compute(shapeenv)
if check_res is not None:
acc.append((*check_res, constraint.make_err_msg("%1", "%2"))) # type: ignore
shape_check_messages: List[str] = []
shape_check_code: jax.Array = np.int32(-1) # type: ignore
shape_check_op1 = shape_check_op2 = shape_check_code
for (is_ok, op1, op2, msg) in reversed(acc):
shape_check_code = lax.select(is_ok, shape_check_code,
np.int32(len(shape_check_messages)))
shape_check_op1 = lax.select(is_ok, shape_check_op1, op1)
shape_check_op2 = lax.select(is_ok, shape_check_op2, op2)
shape_check_messages.append(msg)
return (shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages)
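# A sketch (with hypothetical values) of how a consumer turns the result of
# `compute` into an error message, mirroring the "%1"/"%2" substitution
# performed when running the shape checking module:
#
#   code, op1, op2, messages = shape_constraints.compute(synthetic_env)
#   # ... later, once the three scalars have concrete values:
#   if int(code) != -1:
#     raise ValueError(messages[int(code)]
#                      .replace("%1", str(int(op1)))
#                      .replace("%2", str(int(op2))))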
@dataclasses.dataclass
@ -1256,50 +1336,52 @@ def solve_dim_vars(
args_avals = [ShapedArray((3, a, a + b), f32)]
we introduce fresh "known" dimension variables to represent the actual dimension
size of actual arguments for each non-constant dimension. Each known variable
has a name, an arg_idx, and a dim_idx, e.g.:
we introduce fresh "synthetic" dimension variables to represent the actual
dimension size of the arguments, one for each non-constant dimension.
Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.:
known_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
and then we express the solution for the unknown dimension variables {a, b}
as symbolic expressions in terms of the known variables:
as symbolic expressions in terms of the synthetic variables:
dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])
Not all equations are solvable. For now, we solve first the linear uni-variate
equations, then the solved variables are used to simplify the remaining
equations to linear uni-variate equations, and the process continues
until all dimension variables are solved.
Not all equations are solvable. For now, we solve first the linear
uni-variate equations, then the solved variables are used to simplify the
remaining equations to linear uni-variate equations, and the process
continues until all dimension variables are solved.
Args:
args_avals: the abstract values of the `args`, with shapes that may
include unknown dimension variables.
args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` from
which the flat sequence `args_avals` is extracted. Used for describing
args and kwargs in known variable names and in error messages.
args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)`
from which the flat sequence `args_avals` is extracted. Used for
describing args and kwargs in synthetic variable names and in
error messages.
Returns: a 3-tuple with: (a) the solution for the unknown dimension variables
(b) a list of constraints that must be satisfied for the solution to be a
valid one, and (c) and the list of known variables that may appear in
valid one, and (c) the list of synthetic variables that may appear in
the solution and the constraints.
Raises ValueError if it cannot solve some dimension variable.
"""
dim_equations: List[_DimEquation] = []
known_dimension_vars: List[Tuple[str, int, int]] = []
synth_dimension_vars: List[Tuple[str, int, int]] = []
for arg_idx, aval in enumerate(args_avals):
for dim_idx, aval_d in enumerate(aval.shape):
if is_poly_dim(aval_d):
known_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
arg_idx, dim_idx)
known_dimension_vars.append((known_dim_var, arg_idx, dim_idx))
synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx))
dim_equations.append(
_DimEquation(dim_expr=_ensure_poly(aval_d, "solve_dim_vars"),
dim_value=_DimExpr.from_var(known_dim_var)))
dim_value=_DimExpr.from_var(synth_dim_var)))
solution, shape_constraints = _solve_dim_equations(dim_equations)
return solution, shape_constraints, known_dimension_vars
return solution, shape_constraints, synth_dimension_vars
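# An illustrative sketch of the docstring example above, for a single
# positional array argument; `a` and `b` are built with the internal helper
# `_DimExpr.from_var`.
_a, _b = _DimExpr.from_var("a"), _DimExpr.from_var("b")
_avals = (core.ShapedArray((3, _a, _a + _b), np.float32),)
_tree = tree_util.tree_flatten((_avals, {}))[1]  # PyTreeDef for (args, kwargs)
_solution, _constraints, _synth_vars = solve_dim_vars(
    _avals, args_kwargs_tree=_tree)
# _synth_vars == [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
# _solution["a"] is the synthetic expression args[0].shape[1], and
# _solution["b"] is args[0].shape[2] - args[0].shape[1].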
def compute_dim_vars_from_arg_shapes(
args_avals: Sequence[core.AbstractValue],
@ -1315,17 +1397,47 @@ def compute_dim_vars_from_arg_shapes(
`all_dim_vars(args_avals)`.
"""
dim_vars = all_dim_vars(args_avals)
solution, shape_constraints, known_dim_vars = solve_dim_vars(
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
# Replace the synthetic vars with the dynamic shape of the actual arg
known_env = {vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
for (vname, arg_idx, dim_idx) in known_dim_vars}
dim_values = [solution[var].evaluate(known_env) for var in dim_vars]
shape_constraints.check(known_env)
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars}
dim_values = [solution[var].evaluate(synthetic_env) for var in dim_vars]
return tuple(dim_values)
def compute_shape_check_from_arg_shapes(
args_avals: Sequence[core.AbstractValue],
*actual_args: jax.Array,
args_kwargs_tree: tree_util.PyTreeDef,
acc_shape_check_messages: List[str]) -> Sequence[jax.Array]:
"""Computes the shape check code from the actual arguments.
`acc_shape_check_messages` is an initially empty list where we append the
shape checking messages. Each of these corresponds to a constraint.
See the `Exported.shape_check_module` docstring.
Returns: a triple of JAX arrays: the shape check code and the values of the
two operands for the failed constraint.
Raises ValueError if a constraint is invalidated statically.
"""
assert not acc_shape_check_messages
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
# Replace the synthetic vars with the dynamic shape of the actual arg
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars}
shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages = (
shape_constraints.compute(synthetic_env))
acc_shape_check_messages.extend(shape_check_messages)
return (shape_check_code, shape_check_op1, shape_check_op2)
def _solve_dim_equations(
eqns: List[_DimEquation]
) -> Tuple[DimVarEnv, ShapeConstraints]:
@ -1375,16 +1487,16 @@ def _solve_dim_equations(
var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore
shape_constraints.add_constraint(
ShapeConstraint.Comparator.EQ, var_remainder, 0,
make_err_msg=lambda rem_int, _: (
make_err_msg=lambda left, _: (
f"Dimension variable '{var}' must have integer value >= 1. "
f"Non-zero remainder {rem_int} for factor {factor_var} when solving "
f"Non-zero remainder {left} for factor {factor_var} when solving "
f"{eqn}.{_shapeenv_to_str()}"))
shape_constraints.add_constraint(
ShapeConstraint.Comparator.GEQ, var_value, 1,
make_err_msg=lambda var_int, _: (
make_err_msg=lambda left, _: (
f"Dimension variable '{var}' must have integer value >= 1. "
f"Found {var_int} when "
f"Found {left} when "
f"solving {eqn}.{_shapeenv_to_str()}"))
if not isinstance(var_value, _DimExpr):
@ -1396,8 +1508,8 @@ def _solve_dim_equations(
shape_constraints.add_constraint(
ShapeConstraint.Comparator.EQ, eqn.dim_value,
eqn.dim_expr.evaluate(shapeenv),
make_err_msg=lambda val1, val2: (
f"Found inconsistency {val1} != {val2} when solving {eqn}.{_shapeenv_to_str()}"))
make_err_msg=lambda left, right: (
f"Found inconsistency {left} != {right} when solving {eqn}.{_shapeenv_to_str()}"))
return True
while True:

View File

@ -15,31 +15,150 @@ import contextlib
import logging
import math
import re
from typing import List
from typing import List, Optional, Sequence
import unittest
from absl.testing import absltest, parameterized
from absl.testing import absltest
import jax
from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.jax2tf import jax_export
try:
from jax.experimental.jax2tf import jax2tf # TODO: temporary
except ImportError:
jax2tf = None # type: ignore
from jax.lib import xla_client as xc
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_extension
import numpy as np
config.parse_flags_with_absl()
if xc.mlir_api_version >= 50:
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
skip_shape_check: bool = False):
assert all(core.is_constant_shape(a.shape) for a in args)
return jax_export.call_exported(exp)(*args)
else:
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
skip_shape_check: bool = False):
"""Temporary runner for an Exported.
Normally we would use jax_export.call_exported, but if the exported module has
shape polymorphism and we are using jaxlib before 0.4.12, we run it through
jax2tf (XlaCallModule) instead.
Once we upgrade the jaxlib we can replace all uses of this function with
jax_export.call_exported.
"""
assert all(core.is_constant_shape(a.shape) for a in args)
if not exp.module_uses_dim_vars:
return jax_export.call_exported(exp)(*args)
else:
# We only get here in external tests, because internal ones use newest jaxlib
from jax.experimental.jax2tf import jax2tf # TODO: temporary
# call_exported does the shape checking, we must do it manually for
# XlaCallModule.
if not skip_shape_check:
shape_check = exp.shape_check_module()
if shape_check is not None:
err_msg = _run_shape_check_module(exp, shape_check[0], shape_check[1], args)
if err_msg is not None:
raise ValueError(err_msg)
numpy_results = map(lambda res_tf: res_tf.numpy(),
jax2tf._run_exported_as_tf(args, exp))
return exp.out_tree.unflatten(numpy_results)
def _run_shape_check_module(primal_exported: jax_export.Exported,
shape_check_module_serialized: bytes,
shape_check_messages: Sequence[str],
args: Sequence[jax.Array]
) -> Optional[str]:
"""Helper to run a shape checking module.
We only need to do this in tests, because otherwise this will be done
implicitly when we call XlaCallModule.
Returns: the error message, or None if no error.
"""
args = tuple(args)
# We cannot just make an Exported and run it, because call_exported will
# do the shape checks statically. So, we wrap the shape check module
# with static shape arguments
static_in_avals = tuple(core.get_aval(a) for a in args)
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
wrapped_module = ir.Module.parse(
xla_extension.mlir.deserialize_portable_artifact(
shape_check_module_serialized))
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
# Use static shapes
new_main_input_types = [mlir.aval_to_ir_type(a) for a in static_in_avals]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
)
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
for new_arg, orig_arg_type in zip(
new_main_op.arguments, orig_main.type.inputs
):
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
call = func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")
wrapped_module_serialized, version = jax_export._serialize_module(wrapped_module)
# Make an Exported and then run it
out_avals = (core.ShapedArray((), dtype=np.int32),) * 3
exp = jax_export.Exported(
fun_name=f"shape_check_{primal_exported.fun_name}",
in_tree=tree_util.tree_flatten((args, {}))[1],
in_avals=static_in_avals,
out_tree=tree_util.tree_flatten(out_avals)[1],
out_avals=out_avals,
in_shardings=None,
out_shardings=None,
lowering_platform=primal_exported.lowering_platform,
disabled_checks=(),
mlir_module_serialized=wrapped_module_serialized,
xla_call_module_version=version,
module_kept_var_idx=tuple(sorted(range(len(static_in_avals)))),
module_uses_dim_vars=True,
_get_vjp=lambda _: None) # type: ignore
code, op1, op2 = _temp_call_exported(exp, *args,
skip_shape_check=True)
if code == -1:
return None
else:
return shape_check_messages[code].replace("%1", str(int(op1))).replace(
"%2", str(int(op2)))
class JaxExportTest(jtu.JaxTestCase):
def test_basic_export_only(self):
@ -171,9 +290,10 @@ class JaxExportTest(jtu.JaxTestCase):
r"Dtype mismatch for args\[0\]"):
jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
@parameterized.named_parameters(
dict(testcase_name=p, platform=p)
for p in ("cpu", "cuda", "rocm", "tpu"))
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"],
kwargs=[dict(platform=p)
for p in ("cpu", "cuda", "rocm", "tpu")])
def test_error_wrong_platform(self, platform):
a = np.arange(4, dtype=np.float32)
@ -239,6 +359,7 @@ class JaxExportTest(jtu.JaxTestCase):
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
self.assertAllClose(jax_vjp, exp_vjp)
def test_roundtrip(self):
def f1(x):
return jnp.sin(x)
@ -253,150 +374,199 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
jax_export.call_exported(exp_f2)(a))
# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use jax_export.call_exported and we also run the shape check
# module.
@jtu.parameterized_filterable(
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
#one_containing="",
kwargs=[
dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
# The shape check module does not test constant dimensions
expect_error_run=re.escape(
r"Shape mismatch for args[0].shape[2] (expected same constant)")),
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
dict(poly_spec="3,4,a+1", arg_shape=(3, 4, 1),
expect_error=re.escape(
r"Dimension variable 'a' must have integer "
r"value >= 1. Found 0 when solving "
r"a + 1 == args[0].shape[2].")),
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 13),
expect_error=re.escape(
r"Dimension variable 'a' must have integer value >= 1. "
r"Non-zero remainder 1 for factor 6 when solving "
r"6*a == args[0].shape[2]")),
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 13),
expect_error=re.escape(
r"Found inconsistency 13 != 12 when solving "
r"a + 8 == args[0].shape[2]")),
])
def test_poly_shape_checks(
self, poly_spec="3,a,a+8",
arg_shape=(3, 4, 12), arg_dtype=np.float32,
expect_error=None, # If given, applies for expect_error_run and expect_error_shape_check
expect_error_run=None, # Error from running the exported module
expect_error_shape_check=None): # Error from running the shape check module
if expect_error is not None:
self.assertIsNone(expect_error_run, None)
self.assertIsNone(expect_error_shape_check, None)
expect_error_run = expect_error_shape_check = expect_error
def f(x): # x: f32[poly_spec]
return jnp.reshape(x, (-1, x.shape[1]))
exp_f = jax_export.export(f)(
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape)  # arg : f32[arg_shape]
with contextlib.ExitStack() as stack:
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = _temp_call_exported(exp_f, arg)
if not expect_error_run:
self.assertAllClose(res, f(arg))
# Test the shape_check_module
shape_check = exp_f.shape_check_module()
# We have shape_check only if the exported module has polymorphic input shapes.
if all(core.is_constant_shape(a.shape) for a in exp_f.in_avals):
self.assertIsNone(shape_check)
self.assertIsNone(expect_error_shape_check)
else:
self.assertIsNotNone(shape_check)
shape_check_module, shape_check_messages = shape_check
err_msg = _run_shape_check_module(exp_f,
shape_check_module, shape_check_messages, (arg,))
if expect_error_shape_check is None:
self.assertIsNone(err_msg)
else:
self.assertRegex(err_msg, expect_error_shape_check)
# An inner function is exported with polymorphic shapes inner_poly_spec, and
# is called from an outer function, which is exported with outer_poly_spec.
@parameterized.named_parameters(
dict(testcase_name=f"inner={d['inner_poly_spec']}_outer={d['outer_poly_spec']}", # type: ignore
**d) # type: ignore
for d in (
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
expect_error=re.escape(
r"Dimension variable 'b' must have integer value >= 1. "
r"Found 0 when solving a + b == args[0].shape[2]")),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
# dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: there should be an error
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
expect_error=re.escape(
r"Dimension variable 'a' must have integer value >= 1. "
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
# dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12
# dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: there should be an error 12 != 4
dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
expect_error=r"Rank mismatch for args\[0\]"),
dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
expect_error=r"Dtype mismatch for args\[0\]"),
))
def test_poly(self, inner_poly_spec="3,a,a", inner_x_shape=(3, 4, 6),
inner_x_dtype=np.float32,
outer_poly_spec="3,a,a", outer_x_shape=(3, 4, 12),
expect_error=None):
@jtu.parameterized_filterable(
testcase_name=lambda kw:f"inner={kw['inner_poly_spec']}_outer={kw['outer_poly_spec']}", # type: ignore
#one_containing="",
# By default arg_shape = (3, 4, 12) for both the outer function and the inner
# The inner function is exported for f32.
kwargs=[
# Both inner and outer are static shapes
dict(inner_poly_spec="3,4,12", outer_poly_spec="3,4,12"),
# Inner has poly shapes but outer has static shapes
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,12",
expect_error_outer_exp=re.escape(
r"Found inconsistency 12 != 4 when solving a == args[0].shape[2]")),
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
expect_error_outer_exp=re.escape(
r"Dimension variable 'a' must have integer value >= 1. "
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
dict(inner_poly_spec="3,4,12+a", outer_poly_spec="3,4,12",
expect_error_outer_exp=re.escape(
r"Dimension variable 'a' must have integer value >= 1. "
r"Found 0 when solving a + 12 == args[0].shape[2]")),
# Both inner and outer have poly shapes.
dict(inner_poly_spec="3,a,b", outer_poly_spec="3,4,c"),
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,6*c"),
dict(inner_poly_spec="3,a,a+8", outer_poly_spec="3,c+2,c+10"),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
expect_error_outer_exp=re.escape(
r"Dimension variable 'b' must have integer value >= 1. "
r"Found c + -4 when solving a + b == args[0].shape[2]")),
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
expect_error_outer_exp=re.escape(
r"Found inconsistency c != 4 when solving a == args[0].shape[2]"
)),
dict(inner_poly_spec="3,a,a", arg_shape=(3, 4),
outer_poly_spec="3,c",
expect_error_outer_exp=r"Rank mismatch for args\[0\]"),
dict(inner_poly_spec="3,a,a+b", arg_dtype=np.int32,
outer_poly_spec="3,c,d",
expect_error_outer_exp=r"Dtype mismatch for args\[0\]"),
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c",
expect_error_outer_exp=re.escape(
r"Dimension variable 'a' must have integer value >= 1. "
r"Non-zero remainder mod(c, 5) for factor 5 when solving 5*a == args[0].shape[2]"
)),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
expect_error_outer_exp=re.escape(
r"Dimension variable 'b' must have integer value >= 1. "
r"Found 0 when solving a + b == args[0].shape[2]")),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
expect_error_outer_exp=re.escape(
r"Shape mismatch for args[0].shape[0] (expected same constant)")),
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,25*c",
expect_error_run=re.escape(
r"Dimension variable 'c' must have integer value >= 1. "
r"Non-zero remainder 12 for factor 25 when solving 25*c == args[0].shape[2]")),
dict(inner_poly_spec="3,a,b", outer_poly_spec="3,c+4,12",
expect_error_run=re.escape(
r"Dimension variable 'c' must have integer value >= 1. "
r"Found 0 when solving c + 4 == args[0].shape[1]")),
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
expect_error_run=re.escape(
r"Found inconsistency 12 != 4 when solving "
r"a == args[0].shape[2]")),
])
def test_poly_shape_checks_nested(
self, inner_poly_spec="3,4,5*a",
arg_shape=(3, 4, 12), arg_dtype=np.float32,
outer_poly_spec="3,4,25*c",
expect_error_outer_exp=None,
expect_error_run=None):
# Polymorphic export called with static or polymorphic shapes
def inner(x): # x: inner_poly_spec
return jnp.reshape(x, (-1, x.shape[1]))
inner_x = np.arange(np.prod(inner_x_shape),
dtype=inner_x_dtype).reshape(inner_x_shape) # inner_x : f32[3,4,6]
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
inner_exp = jax_export.export(inner)(
jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
self.assertEqual(inner_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12"))
outer_x = np.arange(np.prod(outer_x_shape),
dtype=np.float32).reshape(outer_x_shape) # outer_x : f32[3,4,12]
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
# result of the call_exported.
return jax_export.call_exported(inner_exp)(x) + inner(x)
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(ValueError, expect_error))
if expect_error_outer_exp is not None:
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
# Call it after exporting again, with polymorphic shapes
outer_exp = jax_export.export(outer)(
jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
self.assertEqual(outer_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50:
res = jax_export.call_exported(outer_exp)(outer_x)
self.assertAllClose(2. * inner(outer_x), res)
else:
# TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
# until we create the python bindings to invoke shape refinement.
if jax2tf is not None:
res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
self.assertAllClose(2. * inner(outer_x), res)
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
def test_call_poly(self):
a_shape = (3, 4)
a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
if expect_error_outer_exp is not None:
return
def f_inner(x): # x: f32[w, h]
return jnp.reshape(x, (-1,))
self.assertEqual(outer_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
shape_check = outer_exp.shape_check_module()
if all(core.is_constant_shape(a.shape) for a in outer_exp.in_avals):
self.assertIsNone(shape_check)
else:
self.assertIsNotNone(shape_check)
exp_inner = jax_export.export(f_inner)(
jax_export.poly_spec(a.shape, a.dtype, "w, h")
)
with contextlib.ExitStack() as stack:
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
# There are dynamic shapes in the exported module
self.assertIn("?x", exp_inner.mlir_module)
self.assertIn("stablehlo.dynamic_reshape", exp_inner.mlir_module)
res = _temp_call_exported(outer_exp, arg)
# Add a wrapper "main" func with static shapes
# TODO(necula): We will add this functionality to jax_export.
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax.lib import xla_client as xc
from jax._src.lib import xla_extension
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
wrapped_module = ir.Module.parse(exp_inner.mlir_module)
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
# Use static shapes
new_main_input_types = [
mlir.aval_to_ir_type(core.ShapedArray((3, 4), np.float32))
]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
)
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
for new_arg, orig_arg_type in zip(
new_main_op.arguments, orig_main.type.inputs
):
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
call = func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")
# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
if xc.mlir_api_version >= 50:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
mlir.module_to_bytecode(wrapped_module)
)
context = mlir.make_ir_context()
with context:
refined_module = ir.Module.parse(refined_module_str)
logging.info("Postprocessed module %s", str(refined_module))
self.assertNotIn("?x", str(refined_module))
self.assertNotIn("stablehlo.dynamic_reshape", str(refined_module))
self.assertIn("stablehlo.reshape", str(refined_module))
if expect_error_run is not None:
return
self.assertAllClose(2. * inner(arg), res)
if __name__ == "__main__":

View File

@ -841,7 +841,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_tf: Callable[..., Any] = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y))
def test_arg_avals(self):
def test_arg_avals_non_native(self):
"""Test conversion of actual arguments to abstract values."""
def check_avals(*, arg_shapes: Sequence[Sequence[Optional[int]]],
@ -857,7 +857,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
dim_vars = shape_poly.all_dim_vars(avals)
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
partial(shape_poly.compute_dim_vars_from_arg_shapes,
avals, args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
avals,
args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
args_tf, avals, "")
if expected_avals is not None:
self.assertEqual(expected_avals, avals)
@ -958,55 +959,55 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
eager_mode=True,
expected_shapeenv=dict(a=2, b=3, c=4))
def test_arg_avals_errors(self):
"""Test error reporting for shape polymorpish."""
def conv_and_run(*, arg_shape: core.Shape,
polymorphic_shape: str):
arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
with self.assertRaisesRegex(ValueError,
re.escape("polymorphic shape spec should be")):
conv_and_run(arg_shape=(2,), polymorphic_shape=5.)
with self.assertRaisesRegex(ValueError,
re.escape("pytree structure error: different types")):
conv_and_run(arg_shape=(2,), polymorphic_shape=["a list"])
with self.assertRaisesRegex(ValueError,
re.escape("pytree structure error: different types")):
conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",))
# The following do not work yet with native serialization because
# XlaCallModule does not yet do shape checking.
if config.jax2tf_default_native_serialization:
return
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
"Cannot solve for values of dimension variables {'b'}"):
check_avals(
arg_shapes=[(4, 36, 3)],
polymorphic_shapes=[PS("b * b", "b * d * d", "d")])
conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
"Dimension variable 'b' must have integer value >= 1"):
check_avals(
arg_shapes=[(5, 36)],
polymorphic_shapes=[PS("3 * b", ...)],
eager_mode=True)
conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
"Dimension variable 'b' must have integer value >= 1"):
check_avals(
arg_shapes=[(10, 3)],
polymorphic_shapes=[PS("3 * b + 10", ...)],
eager_mode=True)
conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
"Dimension variable 'b' must have integer value >= 1"):
check_avals(
arg_shapes=[(7, 3)],
polymorphic_shapes=[PS("3 * b + 10", ...)],
eager_mode=True)
for invalid_syntax in [5.0, ["a list"], ("a tuple",), re.compile(".")]:
with self.assertRaisesRegex(ValueError,
re.escape("Invalid polymorphic shape element")):
check_avals(
arg_shapes=[(2,)], polymorphic_shapes=[PS([invalid_syntax])])
conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(
ValueError,
"Found inconsistency 3 != 2 when solving.*"):
check_avals(
arg_shapes=[(2, 3)],
polymorphic_shapes=["(a, a)"],
eager_mode=True)
# Same error across multiple arguments
with self.assertRaisesRegex(
ValueError,
"Found inconsistency 5 != 2 when solving.*"):
check_avals(
arg_shapes=[(2, 3), (5,)],
polymorphic_shapes=["a, ...", "a"],
eager_mode=True)
conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)")
def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""