[shape_poly] Reimplement the shape constraint checking using shape assertions.

Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality in xla_extension.refine_polymorphic_shapes
that handles @shape_assertion custom calls.

As a beneficial side effect, we now get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
This commit is contained in:
George Necula 2023-07-03 17:31:31 +03:00
parent f97dca79a2
commit e643f98558
11 changed files with 338 additions and 450 deletions

View File

@ -68,6 +68,9 @@ Remember to align the itemized text with the first line of an item within a list
* JAX now supports a configuration flag --jax_serialization_version
and a JAX_SERIALIZATION_VERSION environment variable to control the
serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
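As a quick illustration (a minimal sketch, assuming the flag named in the entry above is settable through `jax.config.update`), the serialization version can be raised to 7 before converting:
```python
# Sketch only: opt in to serialization version 7 so that jax2tf emits the
# shape-constraint checks described above.
import jax

jax.config.update("jax_serialization_version", 7)
# Alternatively, set the environment variable before starting the process:
#   export JAX_SERIALIZATION_VERSION=7
```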
## jaxlib 0.4.14

View File

@ -2044,7 +2044,7 @@ def custom_call(
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
*,
backend_config: Union[str, dict] = "",
backend_config: Union[str, dict[str, ir.Attribute]] = "",
has_side_effect: bool = False,
result_shapes: Optional[Sequence[ir.Value]] = None,
called_computations: Sequence[str] = (),

View File

@ -1230,18 +1230,21 @@ def parameterized_filterable(*,
Args:
kwargs: Each entry is a set of kwargs to be passed to the test function.
testcase_name: Optionally, a function to construct the testcase_name from
one kwargs dict. If not given then the kwarg must contain `testcase_name`.
one kwargs dict. If not given, each kwargs dict may contain a `testcase_name` key;
if it does not, the test case name is constructed as `str(kwarg)`.
one_containing: If given, then leave the test name unchanged, and use
only one `kwargs` whose `testcase_name` includes `one_containing`.
"""
# Ensure that all kwargs contain a testcase_name
kwargs_with_testcase_name: Sequence[dict[str, Any]]
if testcase_name is not None:
kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
kwargs_with_testcase_name = [dict(testcase_name=str(testcase_name(kw)), **kw)
for kw in kwargs]
else:
for kw in kwargs:
assert "testcase_name" in kw
if "testcase_name" not in kw:
kw["testcase_name"] = "_".join(f"{k}={str(kw[k])}"
for k in sorted(kw.keys()))
kwargs_with_testcase_name = kwargs
if one_containing is not None:
filtered = tuple(kw for kw in kwargs_with_testcase_name
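For illustration, a minimal usage sketch of the updated decorator (the test class and kwargs below are made up; the decorator and import path are the ones used elsewhere in this change):
```python
import jax.numpy as jnp
from jax._src import test_util as jtu

class ExampleTest(jtu.JaxTestCase):

  @jtu.parameterized_filterable(
      # testcase_name may be omitted; the name is then built from the sorted
      # kwargs, e.g. "dtype=float32_size=3".
      kwargs=[dict(dtype=dtype, size=size)
              for dtype in ("float32", "int32")
              for size in (3, 7)])
  def test_ones(self, *, dtype, size):
    self.assertEqual(jnp.ones((size,), dtype=dtype).shape, (size,))
```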

View File

@ -547,11 +547,9 @@ cannot be used anymore as dimension parameters and will raise a JAX error.
### Errors in presence of shape polymorphism
If you write your program assuming that all shapes are tuples of integers,
and then try to trace it with shape polymorphism you can run into a number
of errors.
The program:
Most JAX code assumes that the shapes of JAX arrays are tuples of integers,
but with shape polymorphism some dimensions may be symbolic expressions.
This can lead to a number of errors. For example, the program:
```python
four_ones = np.ones((4,))
@ -597,6 +595,30 @@ The solution is to avoid `np.array`, `float`, or JAX arrays in operations whose
results are used as shapes, e.g., instead of `np.arange(n) * x.shape[0]` write
`[i * x.shape[0] for i in range(n)]`.
JAX assumes that dimension variables range over strictly positive integers.
These assumptions are now checked against the shapes of the actual arguments
when the lowered code is invoked.
For example, given the `polymorphic_shapes="(b, b, 2*d)"`
specification, we will generate code to check the following constraints when
invoked with actual argument `arg`:
* `arg.shape[0] >= 1`
* `arg.shape[1] == arg.shape[0]`
* `arg.shape[2] % 2 == 0` and `arg.shape[2] // 2 >= 1`
When using native serialization these are checked by the `tf.XlaCallModule`
op (starting with serialization
[version 7](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+jax_serialization_version&type=code)),
and you will get `tf.errors.InvalidArgument` errors.
You can disable this checking by including `DisabledSafetyCheck.shape_assertions()`
in the `disabled_checks` parameter to `jax2tf.convert`, or by setting
the environment variable
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=shape_assertions`.
When using graph serialization these are checked using `tf.debugging.assert_equal`,
which will also result in `tf.errors.InvalidArgument`.
Note that due to limitations in TensorFlow, these errors are suppressed when using
`jit_compile=True` and when running on TPU.
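For concreteness, here is a minimal sketch of these checks in action (the function and shapes are illustrative; this assumes a serialization mode and version that emits the checks, and that `DisabledSafetyCheck` is importable from `jax.experimental.jax2tf` as described above):
```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, b, 2*d)"])
try:
  # Shape (3, 4, 8) violates arg.shape[1] == arg.shape[0].
  f_tf(np.ones((3, 4, 8), dtype=np.float32))
except tf.errors.InvalidArgumentError as e:
  print(e)

# Opting out of the checks:
f_tf_unchecked = jax2tf.convert(
    jnp.sin, polymorphic_shapes=["(b, b, 2*d)"],
    disabled_checks=(jax2tf.DisabledSafetyCheck.shape_assertions(),))
```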
### Comparison of symbolic dimensions is partially supported
Inside JAX there are a number of equality and inequality comparisons

View File

@ -567,17 +567,6 @@ class GraphSerializationImpl(SerializationImpl):
self.args_avals_flat, args_kwargs_tree=self.in_tree),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
# We invoke shape checking to give it a chance to raise shape errors that
# are evident statically. This should work in TF eager mode because all
# the shapes are known.
# TODO: handle non-static shape checking for graph serialization
acc_shape_check_messages: list[str] = []
_, _ = _interpret_fun_jax(
partial(shape_poly.compute_shape_check_from_arg_shapes,
self.args_avals_flat, args_kwargs_tree=self.in_tree,
acc_shape_check_messages=acc_shape_check_messages),
self.args_flat_tf, self.args_avals_flat, self.name_stack)
_thread_local_state.shape_env = zip(dim_vars, dim_values)
fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
@ -3386,6 +3375,16 @@ def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf
def _shape_assertion_jax2tf(assert_what, *error_message_inputs,
error_message: str):
tf.debugging.assert_equal(
assert_what, True,
message=error_message.format(*error_message_inputs))
return []
tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf
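As a rough sketch of the effect (mirroring the shape_poly tests later in this change), a graph-serialized function with inconsistent dimension sizes now fails eagerly:
```python
# Sketch only: the assert above surfaces as tf.errors.InvalidArgumentError.
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["b, b"],
                      native_serialization=False)
try:
  f_tf(np.ones((4, 5), dtype=np.float32))  # violates b == args[0].shape[1]
except tf.errors.InvalidArgumentError as e:
  print(e)
```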
def _reduce_precision(x, *, exponent_bits, mantissa_bits):
return tfxla.reduce_precision(x, exponent_bits=exponent_bits,
mantissa_bits=mantissa_bits)

View File

@ -82,6 +82,15 @@ class DisabledSafetyCheck:
"""
return DisabledSafetyCheck(f"custom_call:{target_name}")
@classmethod
def shape_assertions(cls) -> "DisabledSafetyCheck":
"""Allows invocations with shapes that do not meet the constraints.
Has effect on serialization (to suppress the generation of the assertions)
and also on deserialization (to suppress the checking of the assertions).
"""
return DisabledSafetyCheck("shape_assertions")
def is_custom_call(self) -> Optional[str]:
"""Returns the custom call target allowed by this directive."""
m = re.match(r'custom_call:(.+)$', self._impl)
@ -163,66 +172,6 @@ class Exported:
def mlir_module(self) -> ir.Module:
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def shape_check_module(self) -> Optional[tuple[bytes, Sequence[str]]]:
"""Generates a serialized shape checking module and the error messages.
Consider the exporting of a function with one array argument of type
"f32[w, 2 * h]", where "w" and "h" are two dimension variables. JAX tracing
assumes that `arg.shape[1]` is even, and that both `w` and `h` have
values >= 1. We must ensure that these assumptions hold when the function
is invoked.
If we `call_exported` for this module we perform these checks
statically (in `call_exported_abstract_eval`).
But if we wrap the exported module with XlaCallModule for execution outside
JAX, we must defer these checks to compile time, when the static shapes
are known.
We generate a separate MLIR shape checking module with a main
function that returns a triple of int32's, with a shape checking code and
two shape checking operands. The shape checking code is -1 if the checks
pass, or is an index into the returned shape_check_messages otherwise.
The returned shape checking operands are substituted for "%1" and "%2" in
the shape checking messages to obtain the error message. Here is a shape
checking module for our example:
func public main(arg: f32[?, ?]) {
# code will be the index of the last failed shape check.
code = op1 = op2 = -1 # Assume no errors initially
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
code = hlo.select(arg_w >= 1, code, 0) # error message 0
op1 = hlo.select(arg_w >= 1, op1, arg_w)
op2 = hlo.select(arg_w >= 1, op2, 1)
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
code = hlo.select(dim1 % 2 == 0, code, 1) # error message 1
op1 = hlo.select(dim1 % 2 == 0, op1, dim1 % 2)
op2 = hlo.select(dim1 % 2 == 0, op2, 0)
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
code = hlo.select(arg_h >= 1, code, 2) # error message 2
op1 = hlo.select(arg_h >= 1, op1, arg_h)
op2 = hlo.select(arg_h >= 1, op2, 1)
return (code, op1, op2)
We also return the following shape checking messages to be used for
constructing error messages:
shape_check_messages = [
"Dimension variable 'w' must have integer value >= 1. Found %1",
"Dimension variable 'h' must have integer value >= 1. Found non-zero remainder %1",
"Dimension variable 'h' must have integer value >= 1. Found %1"]
"""
shape_check_res = _make_shape_check_module(
self.in_avals, args_kwargs_tree=self.in_tree)
if shape_check_res is None:
return None
module, shape_check_messages = shape_check_res
module_serialized, xla_call_module_version = _serialize_module(module)
assert xla_call_module_version == self.xla_call_module_version
return module_serialized, shape_check_messages
def __str__(self):
# This is called to make a MLIR source location when we call an Exported, and we
# do not want the entire serialized module to end up in locations.
@ -359,6 +308,8 @@ def export(fun_jax: Callable,
exported = jax_export.export(f_jax)(*args, **kwargs)
"""
fun_name = getattr(fun_jax, "__name__", "unknown")
version = config.jax_serialization_version
def do_export(*args_specs, **kwargs_specs) -> Exported:
if not hasattr(fun_jax, "lower"):
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
@ -373,28 +324,39 @@ def export(fun_jax: Callable,
allow_non_replicated_sharding = True
lowering_platform_str = lowering_platform or default_lowering_platform()
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_platform=lowering_platform_str)
lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
if "kept_var_idx" in lowering.compile_args:
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals_flat)
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
)
mlir_module_serialized, xla_call_module_version = _serialize_module(mlir_module)
# Do not include shape assertions if the version is < 7.
enable_shape_assertions = (
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
version >= 7) # type: ignore
try:
prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions
shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_platform=lowering_platform_str)
lowering = lowered._lowering # type: ignore
_check_lowering(lowering)
mlir_module = lowering.stablehlo()
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
if "kept_var_idx" in lowering.compile_args:
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals_flat)))
shape_poly_state = lowering.compile_args["shape_poly_state"]
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
or lowering.compile_args.get("ordered_effects", [])):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals_flat)
mlir_module = _wrap_main_func(
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
)
finally:
shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
mlir_module_serialized = _serialize_module(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowering.compile_args:
@ -408,7 +370,7 @@ def export(fun_jax: Callable,
# Log and then check the module.
if logging.vlog_is_on(3):
mlir_module_text = mlir.module_to_string(mlir_module)
logmsg = (f"version={xla_call_module_version} "
logmsg = (f"version={version} "
f"lowering_platform={lowering_platform_str} "
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
@ -432,14 +394,13 @@ def export(fun_jax: Callable,
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
module_uses_dim_vars=shape_poly_state.uses_dim_vars,
xla_call_module_version=xla_call_module_version,
xla_call_module_version=version, # type: ignore
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
return do_export
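A hedged usage sketch of the updated export path (mirroring the jax_export tests later in this change; the function and shapes are illustrative):
```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.jax2tf import jax_export

def f(x):  # x: f32[3, a, a+8]
  return jnp.reshape(x, (-1, x.shape[1]))

# Shape assertions can be suppressed explicitly, e.g. for consumers that do
# not yet understand the shape_assertion custom call.
exp = jax_export.export(
    f, disabled_checks=(jax_export.DisabledSafetyCheck.shape_assertions(),)
)(jax_export.poly_spec((3, 4, 12), np.float32, "3,a,a+8"))
res = jax_export.call_exported(exp)(np.ones((3, 4, 12), dtype=np.float32))
```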
def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
xla_call_module_version = config.jax_serialization_version
def _serialize_module(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
if hlo.get_api_version() < 4:
target_version = hlo.get_earliest_forward_compatible_version()
@ -463,7 +424,7 @@ def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
target_version = hlo.get_minimum_version()
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
return module_serialized, xla_call_module_version
return module_serialized
def _wrap_main_func(
@ -481,9 +442,9 @@ def _wrap_main_func(
the `args_avals`, in sorted order.
Consider the lowering of a function with one array argument of type "f32[w,
2 * h]",
where "w" and "h" are two dimension variables. The `module` will also
contain two dimension arguments, corresponding to "h" and "w" respectively:
2 * h]", where "w" and "h" are two dimension variables. The `module` will
also contain two dimension arguments, corresponding to "h" and "w"
respectively:
func public main(arg_h: i32, arg_w: i32, arg: f32[?, ?]) {
...
@ -495,6 +456,7 @@ def _wrap_main_func(
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
res = call _wrapped_jax_export_main(arg_h, arg_w, arg)
return res
}
@ -503,6 +465,31 @@ def _wrap_main_func(
function by providing dummy values. This ensures that the main function's
calling convention is as expected.
Note that the lowering contains a call to `_check_shape_assertions`.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
module. We use a special custom call `@shape_assertion` that takes
a boolean first operand, a string `error_message` attribute that may contain
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
scalar operands corresponding to the format specifiers.
func private _check_shape_assertions(arg: f32[?, ?]) {
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
custom_call @shape_assertion(arg_w >= 1, arg_w,
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
If we `call_exported` with this module we perform these checks
statically (in `call_exported_abstract_eval`).
Args:
module: the HLO module as obtained from lowering. May have a number of
dimension arguments, followed by the kept array arguments.
@ -601,55 +588,6 @@ def _wrap_main_func(
return wrapped_module
def _make_shape_check_module(
args_avals_flat: Sequence[core.AbstractValue],
*,
args_kwargs_tree: tree_util.PyTreeDef
) -> Optional[tuple[ir.Module, Sequence[str]]]:
"""Codegens the shape checking function.
The shape checking function takes the array inputs and returns a triple.
See `Exported.shape_check_module` docstring.
Returns the shape checking module and the tuple of shape checking messages.
"""
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
array_input_types = tuple(mlir.aval_to_ir_type(a) for a in args_avals_flat)
module = ir.Module.create(loc=ir.Location.unknown(context))
symbol_table = ir.SymbolTable(module.operation)
shape_check_code_type = mlir.aval_to_ir_type(core.ShapedArray((), np.int32))
shape_check_output_types = (shape_check_code_type,) * 3
shape_check_ftype = ir.FunctionType.get(array_input_types,
shape_check_output_types)
shape_check_func = func_dialect.FuncOp(
"main", shape_check_ftype,
ip=ir.InsertionPoint.at_block_begin(module.body))
shape_check_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(shape_check_func)
entry_block = shape_check_func.add_entry_block()
with ir.InsertionPoint(entry_block):
module_context = mlir.ModuleContext(
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals_flat,
avals_out=None, tokens_in=mlir.TokenSet(),
tokens_out=None)
acc_shape_check_messages: list[str] = []
values = mlir.lower_fun(
functools.partial(shape_poly.compute_shape_check_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree,
acc_shape_check_messages=acc_shape_check_messages),
multiple_results=True)(ctx, *shape_check_func.arguments)
func_dialect.ReturnOp(util.flatten(values))
return (module, acc_shape_check_messages) if acc_shape_check_messages else None
def _check_lowering(lowering) -> None:
if not isinstance(lowering, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
@ -749,6 +687,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
# See https://github.com/openxla/stablehlo/issues/8.
"stablehlo.dynamic_reduce_window",
"stablehlo.dynamic_rng_bit_generator",
"shape_assertion", # Used by shape_poly to evaluate assertions
}
@ -763,6 +702,7 @@ def _check_module(mod: ir.Module, *,
disabled_checks: the safety checks that are disabled.
"""
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
for dc in disabled_checks:
target = dc.is_custom_call()
@ -799,6 +739,8 @@ def _check_module(mod: ir.Module, *,
disallowed_custom_call_ops.append(f"{op} at {op.location}")
if call_target_name_attr == sharding_attr:
check_sharding(op, op.location)
elif call_target_name_attr == shape_assertion_attr:
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)
def walk_operations(op):
check_op(op)

View File

@ -38,6 +38,7 @@ import itertools
import io
import math
import operator as op
import threading
import tokenize
from typing import Any, Callable, Iterable, Optional, Sequence, Union
@ -50,6 +51,7 @@ from jax.interpreters import xla
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.lax import lax
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
@ -79,6 +81,15 @@ for more details.
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore
class _ShapePolyThreadLocalState(threading.local):
def __init__(self):
# TODO(necula): this does not play well with some lowering caches, because
# this state is not part of the cache key.
self.enable_shape_assertions = True
thread_local_state = _ShapePolyThreadLocalState()
class _DimAtom:
"""Represents an atom in a symbolic dimension expression.
@ -792,6 +803,62 @@ def _einsum_contract_path(*operands, **kwargs):
lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
# To implement shape-constraint checking we use a shape assertion primitive.
# shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
# error_message="...{0}...{1}")
# where "{0}" refers to error_message_inputs[0], etc.
shape_assertion_p = core.Primitive("shape_assertion")
shape_assertion_p.multiple_results = True
shape_assertion_p.def_effectful_abstract_eval(
lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore
def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
assert_what: mlir.ir.Value,
*error_message_inputs: mlir.ir.Value,
error_message: str):
op = mlir.custom_call(
"shape_assertion",
[], # No results
[assert_what, *error_message_inputs],
has_side_effect=True,
extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
)
return op.results
mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule)
class ShapeAssertionEffect(effects.Effect):
__str__ = lambda _: "ShapeAssertionEffect"
shape_assertion_effect = ShapeAssertionEffect()
effects.lowerable_effects.add_type(ShapeAssertionEffect)
effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect)
effects.remat_allowed_effects.add_type(ShapeAssertionEffect)
effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect)
def shape_assertion(assert_what: jax.Array,
*error_message_inputs: jax.Array,
error_message: str) -> None:
"""Adds a shape assertion in the code.
Args:
assert_what: a boolean asserted to be true. Must be computed based only
on dimension expressions, so that it can be evaluated after shape
refinement.
error_message_inputs: integer expressions whose values can be referenced
in the `error_message`. Must be computed based only
on dimension expressions, so that they can be evaluated after shape
refinement.
error_message: an error message, possibly containing format specifiers
{0}, {1}, ..., referencing the values of the `error_message_inputs`.
The format specifiers are sometimes processed with Python's
`str.format` method, and sometimes with `llvm::formatv`.
"""
if thread_local_state.enable_shape_assertions:
shape_assertion_p.bind(assert_what, *error_message_inputs,
error_message=error_message)
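A minimal sketch of how this helper is meant to be called (the dimension value below is assumed to be an integer JAX array derived from dimension expressions, e.g. via `dimension_size_p`, so it survives shape refinement):
```python
# Sketch only: assert that a dynamic dimension is even.
def _assert_even(dim1):
  shape_assertion(
      dim1 % 2 == 0, dim1 % 2,
      error_message="Expected an even dimension, found remainder {0}")
```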
# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimExpr. The value of the primitive is the value of the dimension,
# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
@ -1055,22 +1122,6 @@ def _evaluate_multiply(v1, v2):
pass
return v1 * v2
def _is_known_constant(v) -> Optional[int]:
try:
return int(v)
except Exception:
# TODO(necula): added this so that in jax2tf, in Eager mode, we can tell
# that a tensor is a constant. We should move this dependency into some
# jax2tf-specific area.
if hasattr(v, "val"):
try:
vint = int(v.val)
if isinstance(vint, int): # In TF, int(tf.Tensor) is tf.Tensor!
return vint
except Exception:
pass
return None
# dimension_size(operand, dimension=i) get the operand.shape[i] as a
# value of type shape_poly.dim_as_value_dtype().
dimension_size_p = core.Primitive("dimension_size")
@ -1180,8 +1231,8 @@ class ShapeConstraint:
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
self.right]]
# Try to evaluate the constraint statically.
left_int, right_int = _is_known_constant(left), _is_known_constant(right)
if left_int is not None and right_int is not None:
if core.is_constant_shape((left, right)):
left_int, right_int = op.index(left), op.index(right)
if self.comp == ShapeConstraint.Comparator.EQ:
if not (left_int == right_int):
raise ValueError(self.make_err_msg(left_int, right_int))
@ -1225,34 +1276,21 @@ class ShapeConstraints:
for constraint in self.constraints:
constraint.check_statically(shapeenv)
def compute(self, shapeenv: DimVarEnv) -> tuple[jax.Array, jax.Array, jax.Array, Sequence[str]]:
"""Computes the error code for the set of constraints.
def shape_assertions(self, shapeenv: DimVarEnv) -> None:
"""Computes the shape assertions for the set of constraints.
The error code is -1 if all constraints are satisfied, or an index into
the returned error messages.
See Exported.shape_check_module docstring.
See _wrap_main_func docstring.
"""
# We want to report the errors in the same order as `check_statically`.
# So, we process them in order, in case some fail statically, but we
# accumulate the errors in reverse order in the computation, because the
# code-generation strategy will report the last error that failed.
acc: list[tuple[jax.Array, jax.Array, jax.Array, str]] = []
# So, we process them in order, in case some fail statically, and we
# generate the shape assertions in the same order.
for constraint in self.constraints:
check_res = constraint.compute(shapeenv)
if check_res is not None:
acc.append((*check_res, constraint.make_err_msg("%1", "%2"))) # type: ignore
shape_check_messages: list[str] = []
shape_check_code: jax.Array = np.int32(-1) # type: ignore
shape_check_op1 = shape_check_op2 = shape_check_code
for (is_ok, op1, op2, msg) in reversed(acc):
shape_check_code = lax.select(is_ok, shape_check_code,
np.int32(len(shape_check_messages)))
shape_check_op1 = lax.select(is_ok, shape_check_op1, op1)
shape_check_op2 = lax.select(is_ok, shape_check_op2, op2)
shape_check_messages.append(msg)
return (shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages)
is_ok, op1, op2 = check_res
shape_assertion(
is_ok, op1, op2,
error_message=constraint.make_err_msg("{0}", "{1}"))
@dataclasses.dataclass
class _DimEquation:
@ -1352,7 +1390,8 @@ def compute_dim_vars_from_arg_shapes(
Like `solve_dim_vars` except that here we express the solution as
JAX arrays that reference the `actual_args`. This function can be used to
generate the code for computing the dimension variables.
generate the code for computing the dimension variables. It also generates
the shape assertions.
Returns: the values of the dimension variables, in the order determined by
`all_dim_vars(args_avals)`.
@ -1366,39 +1405,9 @@ def compute_dim_vars_from_arg_shapes(
dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars}
dim_values = [solution[var].evaluate(synthetic_env) for var in dim_vars]
shape_constraints.shape_assertions(synthetic_env)
return tuple(dim_values)
def compute_shape_check_from_arg_shapes(
args_avals: Sequence[core.AbstractValue],
*actual_args: jax.Array,
args_kwargs_tree: tree_util.PyTreeDef,
acc_shape_check_messages: list[str]) -> Sequence[jax.Array]:
"""Computes the shape check code from the actual arguments.
`acc_shape_check_messages` is an initially empty list where we append the
shape checking messages. Each of these correspond to a constraint.
See the `Exported.shape_check_module` docstring.
Returns: a triple of JAX arrays: the shape check code and the values of the
two operands for the failed constraint.
Raises ValueError if a constraint is invalidated statically.
"""
assert not acc_shape_check_messages
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
# Replace the synthetic vars with the dynamic shape of the actual arg
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
dimension=dim_idx)
for (vname, arg_idx, dim_idx) in synth_dim_vars}
shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages = (
shape_constraints.compute(synthetic_env))
acc_shape_check_messages.extend(shape_check_messages)
return (shape_check_code, shape_check_op1, shape_check_op2)
def _solve_dim_equations(
eqns: list[_DimEquation]
) -> tuple[DimVarEnv, ShapeConstraints]:

View File

@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
covered_targets = covered_targets.union({
"tpu_custom_call", # tested separately
# TODO(necula): add tests for shape_assertion
"shape_assertion",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered)
@ -608,7 +610,9 @@ class CompatTest(bctu.CompatTestBase):
data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17)
self.run_one_test(
func, data,
polymorphic_shapes=("b, ...",))
polymorphic_shapes=("b, ...",),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
# stablehlo.dynamic_reduce_window is used temporarily on TPU for a
@ -629,7 +633,9 @@ class CompatTest(bctu.CompatTestBase):
data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17)
self.run_one_test(
func, data,
polymorphic_shapes=("b, ...", "b, ..."))
polymorphic_shapes=("b, ...", "b, ..."),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
def test_stablehlo_dynamic_rbg_bit_generator(self):
# stablehlo.dynamic_rbg_bit_generator is used temporarily for a
@ -660,7 +666,9 @@ class CompatTest(bctu.CompatTestBase):
try:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"))
self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"),
# TODO(necula): now also includes shape_assertion
compare_with_current=False)
finally:
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)

View File

@ -13,7 +13,6 @@
# limitations under the License.
import contextlib
import re
from typing import Optional, Sequence
import unittest
from absl.testing import absltest
@ -28,136 +27,14 @@ from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_extension
import numpy as np
config.parse_flags_with_absl()
if xc.mlir_api_version >= 50:
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
skip_shape_check: bool = False):
assert all(core.is_constant_shape(a.shape) for a in args)
return jax_export.call_exported(exp)(*args)
else:
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
skip_shape_check: bool = False):
"""Temporary runner for an Exported.
Normally we would use jax_export.call_exported, but if the exported has
shape polymorphism and we are using jaxlib before 0.4.12 we use
Once we upgrade the jaxlib we can replace all uses of this function with
jax_export.call_exported.
"""
assert all(core.is_constant_shape(a.shape) for a in args)
if not exp.module_uses_dim_vars:
return jax_export.call_exported(exp)(*args)
else:
# We only get here in external tests, because internal ones use newest jaxlib
from jax.experimental.jax2tf import jax2tf # TODO: temporary
# call_exported does the shape checking, we must do it manually for
# XlaCallModule.
if not skip_shape_check:
shape_check = exp.shape_check_module()
if shape_check is not None:
err_msg = _run_shape_check_module(exp, shape_check[0], shape_check[1], args)
if err_msg is not None:
raise ValueError(err_msg)
numpy_results = map(lambda res_tf: res_tf.numpy(),
jax2tf._run_exported_as_tf(args, exp))
return exp.out_tree.unflatten(numpy_results)
def _run_shape_check_module(primal_exported: jax_export.Exported,
shape_check_module_serialized: bytes,
shape_check_messages: Sequence[str],
args: Sequence[jax.Array]
) -> Optional[str]:
"""Helper to run a shape checking module.
We only need to do this in tests, because otherwise this will be done
implicitly when we call XlaCallModule.
Returns: the error message, or None if no error.
"""
args = tuple(args)
# We cannot just make an Exported and run it, because call_exported will
# do the shape checks statically. So, we wrap the shape check module
# with static shape arguments
static_in_avals = tuple(core.get_aval(a) for a in args)
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
wrapped_module = ir.Module.parse(
xla_extension.mlir.deserialize_portable_artifact(
shape_check_module_serialized))
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
# Use static shapes
new_main_input_types = [mlir.aval_to_ir_type(a) for a in static_in_avals]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
)
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: list[ir.Value] = []
for new_arg, orig_arg_type in zip(
new_main_op.arguments, orig_main.type.inputs
):
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
call = func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")
wrapped_module_serialized, version = jax_export._serialize_module(wrapped_module)
# Make an Exported and then run it
out_avals = (core.ShapedArray((), dtype=np.int32),) * 3
exp = jax_export.Exported(
fun_name=f"shape_check_{primal_exported.fun_name}",
in_tree=tree_util.tree_flatten((args, {}))[1],
in_avals=static_in_avals,
out_tree=tree_util.tree_flatten(out_avals)[1],
out_avals=out_avals,
in_shardings=None,
out_shardings=None,
lowering_platform=primal_exported.lowering_platform,
disabled_checks=(),
mlir_module_serialized=wrapped_module_serialized,
xla_call_module_version=version,
module_kept_var_idx=tuple(sorted(range(len(static_in_avals)))),
module_uses_dim_vars=True,
_get_vjp=lambda _: None) # type: ignore
code, op1, op2 = _temp_call_exported(exp, *args,
skip_shape_check=True)
if code == -1:
return None
else:
return shape_check_messages[code].replace("%1", str(int(op1))).replace(
"%2", str(int(op2)))
class JaxExportTest(jtu.JaxTestCase):
def test_basic_export_only(self):
@ -396,7 +273,7 @@ class JaxExportTest(jtu.JaxTestCase):
dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
# The shape check module does not test constant dimensions
expect_error_run=re.escape(
expect_error=re.escape(
r"Shape mismatch for args[0].shape[2] (expected same constant)")),
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
@ -418,49 +295,31 @@ class JaxExportTest(jtu.JaxTestCase):
def test_poly_shape_checks(
self, poly_spec="3,a,a+8",
arg_shape=(3, 4, 12), arg_dtype=np.float32,
expect_error=None, # If given, applies for expect_error_run and expect_error_shape_check
expect_error_run=None, # Error from running the exported module
expect_error_shape_check=None): # Error from running the shape check module
if expect_error is not None:
self.assertIsNone(expect_error_run, None)
self.assertIsNone(expect_error_shape_check, None)
expect_error_run = expect_error_shape_check = expect_error
expect_error=None): # If given, error from running the exported module
def f(x): # x: f32[poly_spec]
return jnp.reshape(x, (-1, x.shape[1]))
exp_f = jax_export.export(f)(
if xc.mlir_api_version <= 51:
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
else:
disabled_checks = ()
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
with contextlib.ExitStack() as stack:
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error))
res = _temp_call_exported(exp_f, arg)
assert core.is_constant_shape(arg.shape)
res = jax_export.call_exported(exp_f)(arg)
if not expect_error_run:
if not expect_error:
self.assertAllClose(res, f(arg))
# Test the shape_check_module
shape_check = exp_f.shape_check_module()
# We have shape_check only if the exported has polymorphic inputs shapes.
if all(core.is_constant_shape(a.shape) for a in exp_f.in_avals):
self.assertIsNone(shape_check)
self.assertIsNone(expect_error_shape_check)
else:
self.assertIsNotNone(shape_check)
shape_check_module, shape_check_messages = shape_check
err_msg = _run_shape_check_module(exp_f,
shape_check_module, shape_check_messages, (arg,))
if expect_error_shape_check is None:
self.assertIsNone(err_msg)
else:
self.assertRegex(err_msg, expect_error_shape_check)
# An inner function is exported with polymorphic shapes inner_poly_spec, and
# is called from an outer function, which is exported with outer_poly_spec.
@jtu.parameterized_filterable(
@ -536,12 +395,16 @@ class JaxExportTest(jtu.JaxTestCase):
expect_error_outer_exp=None,
expect_error_run=None):
# Polymorphic export called with static or polymorphic shapes
if xc.mlir_api_version <= 51:
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
else:
disabled_checks = ()
def inner(x): # x: inner_poly_spec
return jnp.reshape(x, (-1, x.shape[1]))
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
inner_exp = jax_export.export(inner)(
inner_exp = jax_export.export(inner, disabled_checks=disabled_checks)(
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
self.assertEqual(inner_exp.module_uses_dim_vars,
@ -556,7 +419,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
# Call it after exporting again, with polymorphic shapes
outer_exp = jax_export.export(outer)(
outer_exp = jax_export.export(outer, disabled_checks=disabled_checks)(
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
if expect_error_outer_exp is not None:
@ -564,17 +427,12 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(outer_exp.module_uses_dim_vars,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
shape_check = outer_exp.shape_check_module()
if all(core.is_constant_shape(a.shape) for a in outer_exp.in_avals):
self.assertIsNone(shape_check)
else:
self.assertIsNotNone(shape_check)
with contextlib.ExitStack() as stack:
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = _temp_call_exported(outer_exp, arg)
res = jax_export.call_exported(outer_exp)(arg)
if expect_error_run is not None:
return

View File

@ -87,7 +87,9 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", tshape))
a, b = shape_poly._parse_spec("a, b", (2, 3))
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_{dim_spec}",
dim_spec=dim_spec, dim_poly=dim_poly)
for dim_spec, dim_poly in [
@ -100,34 +102,36 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
("a + -1", a - 1),
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
])
]])
def test_parse_dim(self,
dim_spec="-2 * a^2 * b + b^2",
dim_poly=-2 * a * a * b + b * b):
self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (None,)))
self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (None,)))
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_{shape_spec=}",
shape_spec=shape_spec)
for shape_spec in [
"2.5", "a + a a", "a ^ a", "a, a",
"_", "...", "a ;", ")(", "2a", "a@", "'a'", "('a', ...)",
"mod(a)", "floordiv(a, b, c)", "..., 3"
])
]])
def test_parse_error(self,
shape_spec="a + a a"):
with self.assertRaisesRegex(ValueError,
"syntax error in polymorphic shape"):
shape_poly._parse_spec(shape_spec, (None,))
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_{shape_spec=}",
shape_spec=shape_spec, arg_shape=arg_shape)
for shape_spec, arg_shape in [
("3", (4,)),
("b, 3", (None, 4)),
])
]])
def test_parse_mismatch_error(self,
shape_spec="3", arg_shape=(4,)):
with self.assertRaisesRegex(ValueError,
@ -347,7 +351,8 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(a * 2 // a, 2)
self.assertIsInstance(a * 2 // a, int)
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
dividend=dividend, divisor=divisor, quotient=quotient,
remainder=remainder)
@ -364,7 +369,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"),
(3, a, "floordiv(3, a)", "mod(3, a)"),
])
]])
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
if isinstance(quotient, str):
d1, d2 = divmod(dividend, divisor)
@ -930,37 +935,26 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
re.escape("pytree structure error: different types")):
conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",))
# The following do not work yet with native serialization because
# XlaCallModule does not yet do shape checking.
if config.jax2tf_default_native_serialization:
return
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
"Cannot solve for values of dimension variables {'b'}"):
conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
"Dimension variable 'b' must have integer value >= 1"):
conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
"Dimension variable 'b' must have integer value >= 1"):
conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
"Dimension variable 'b' must have integer value >= 1"):
conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...")
# TODO(necula): enable even for native serialization
with self.assertRaisesRegex(
ValueError,
tf.errors.InvalidArgumentError,
"Found inconsistency 3 != 2 when solving.*"):
conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)")
def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""
@ -1250,7 +1244,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
self.assertAllClose(f(x), res_tf)
@jtu.sample_product(with_function=[False, True])
@jtu.parameterized_filterable(
kwargs=[dict(with_function=v) for v in [True, False]]
)
def test_grad_int(self, with_function=False):
# https://github.com/google/jax/issues/7093
# Also issue #6975.
@ -1418,34 +1414,56 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
polymorphic_shapes=["(a, b)"])(np.ones((4, 4)))
# Unsoundness: not checking that the dimension variable is 0
# Checking that the dimension variable is >= 1
def f1_jax(x): # f32[b]
# We have to use "x"
return jnp.concatenate([x, jnp.array([0. if x.shape[0] == 0 else 1.],
dtype=np.float32)])
x0 = np.array([], np.float32)
# JAX with static shapes sees that the x.shape[0] == 0
self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
# In graph serialization eager mode we catch the error
with self.assertRaisesRegex(
ValueError,
tf.errors.InvalidArgumentError,
"Dimension variable 'b' must have integer value >= 1. Found 0"):
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False)(x0)
# In native serialization, or if we trace to a TF graph, we miss this
res1_tf = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=True)(x0)
self.assertEqual(jnp.array([1.], dtype=np.float32), res1_tf)
# In graph serialization graph mode we also catch it (except on TPU)
f1_tf = tf.function(
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False)
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)
if jtu.device_under_test() == "tpu":
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
"Dimension variable .* must have integer value"):
_ = f1_tf(x0)
# In graph serialization with jit_compile=True we do not catch the error
# and we return the wrong result
f1_tf = tf.function(
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False),
autograph=False,
jit_compile=True
)
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
# Unsoundness: not checking that the actual dimensions denoted by the same
# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
"Dimension variable 'b' must have integer value >= 1. Found 0"):
_ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=True)(x0)
# Checking that the actual dimensions denoted by the same
# dimension variables have equal sizes.
def f2_jax(x): # f32[b, b]
# We have to use "x"
@ -1455,25 +1473,46 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# JAX with static shapes sees that x.shape[0] != x.shape[1]
self.assertEqual(jnp.sum(x45), f2_jax(x45))
# jax2tf catches the broken assumption b >= 1 if the converted function is executed
# eagerly.
# In graph serialization eager mode, we catch the broken assumption b >= 1
with self.assertRaisesRegex(
ValueError,
tf.errors.InvalidArgumentError,
r"Found inconsistency 5 != 4 when solving b == args\[0\].shape\[1\]"):
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)(x45)
# In native serialization, or if we trace to a TF graph, we miss this
res2_tf = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=True)(x45)
self.assertEqual(1. + jnp.sum(x45), res2_tf)
# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)
f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
if jtu.device_under_test() == "tpu":
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r"Found inconsistency"):
_ = f2_tf(x45)
# In graph serialization with jit_compile=True we do not catch the error
# and we return the wrong result
f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False),
autograph=False,
jit_compile=True
)
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
"Found inconsistency 5 != 4"):
_ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=True)(x45)
x = np.ones((5,), dtype=np.float32)
with self.assertRaisesRegex(
ValueError,

View File

@ -27,7 +27,6 @@ from typing import Any, Sequence
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
@ -180,12 +179,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
else:
assert False
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
)
])
@jtu.with_mesh([("x", 2)])
def test_pjit_basic(self, in_shardings="P", out_shardings="P"):
# Ensure that we can distinguish the inputs and outputs by shape
@ -316,14 +316,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
res_tf = f_tf(x)
self.assertAllClose(res_tf, res_jax)
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}",
nested_pjit=nested_pjit, constraint=constraint, poly=poly)
# We add a constraint either with a nested pjit or with a sharding_constraint
for nested_pjit in (True, False)
for constraint in (None, "P")
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
])
@jtu.with_mesh([("x", 2)])
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"):
constraint_sharding = P("x", None) if constraint == "P" else None
@ -363,12 +364,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
(r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated)
])
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
)
])
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="P", out_shardings=None):
def f_jax(x): # x: f32[10,20] -> f32[20,10]
@ -414,7 +416,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
])
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}",
kind=kind, in_shardings=in_shardings, out_shardings=out_shardings)
for kind in ("pjit", "jit", "sharding_constraint")
@ -425,7 +428,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
for out_shardings in (
("unspecified",) if kind in ["sharding_constraint", "jit"] else
("unspecified", "none", "P"))
)
])
def test_pjit_error_inner_sharding(self, kind="pjit", in_shardings="P",
out_shardings="none"):
# Check that we raise an error if there is no top-level pjit but we convert
@ -462,11 +465,12 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
with Mesh(self.devices, axis_names=("x",)):
f_tf(x)
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_func={func}", func=func)
for func in ("pjit_sharded", "pjit_replicated",
"nested_pjit_sharded", "nested_pjit_replicated")
)
])
def test_pjit_eager_error(self, func="pjit_sharded"):
if config.jax2tf_default_native_serialization:
raise unittest.SkipTest("There is no error in eager mode for native serialization")
@ -693,10 +697,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
res_tf = f_tf(a)
self.assertAllClose(res_tf, expected)
@parameterized.named_parameters(
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_poly={poly}", poly=poly)
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
])
def test_shmap_collective_permute(self, poly=None):
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")