diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1498ff306..428dc7d3d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,6 +68,9 @@ Remember to align the itemized text with the first line of an item within a list
   * JAX now supports a configuration flag --jax_serialization_version
     and a JAX_SERIALIZATION_VERSION environment variable to control the
     serialization version ({jax-issue}`#16746`).
+  * jax2tf in the presence of shape polymorphism now generates code that
+    checks certain shape constraints, if the serialization version is at least 7.
+    See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.

## jaxlib 0.4.14

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 518bac48c..b8b52a284 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -2044,7 +2044,7 @@ def custom_call(
     out_types: Sequence[ir.Type],
     operands: Sequence[ir.Value],
     *,
-    backend_config: Union[str, dict] = "",
+    backend_config: Union[str, dict[str, ir.Attribute]] = "",
     has_side_effect: bool = False,
     result_shapes: Optional[Sequence[ir.Value]] = None,
     called_computations: Sequence[str] = (),
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 1501a2ae9..326633f03 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1230,18 +1230,21 @@ def parameterized_filterable(*,
  Args:
    kwargs: Each entry is a set of kwargs to be passed to the test function.
    testcase_name: Optionally, a function to construct the testcase_name from
-      one kwargs dict. If not given then the kwarg must contain `testcase_name`.
+      one kwargs dict. If not given then each kwargs dict may contain
+      `testcase_name`; if that is also missing, the test case name is
+      constructed by joining the sorted `key=value` pairs.
    one_containing: If given, then leave the test name unchanged, and use
       only one `kwargs` whose `testcase_name` includes `one_containing`.
  """
  # Ensure that all kwargs contain a testcase_name
  kwargs_with_testcase_name: Sequence[dict[str, Any]]
  if testcase_name is not None:
-    kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
+    kwargs_with_testcase_name = [dict(testcase_name=str(testcase_name(kw)), **kw)
                                 for kw in kwargs]
  else:
    for kw in kwargs:
-      assert "testcase_name" in kw
+      if "testcase_name" not in kw:
+        kw["testcase_name"] = "_".join(f"{k}={str(kw[k])}"
+                                       for k in sorted(kw.keys()))
    kwargs_with_testcase_name = kwargs
  if one_containing is not None:
    filtered = tuple(kw for kw in kwargs_with_testcase_name
diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md
index af5a190d9..0c83a15b1 100644
--- a/jax/experimental/jax2tf/README.md
+++ b/jax/experimental/jax2tf/README.md
@@ -547,11 +547,9 @@ cannot be used anymore as dimension parameters and will raise a JAX error.

 ### Errors in presence of shape polymorphism

-If you write your program assuming that all shapes are tuples of integers,
-and then try to trace it with shape polymorphism you can run into a number
-of errors.
-
-The program:
+Most JAX code assumes that the shapes of JAX arrays are tuples of integers,
+but with shape polymorphism some dimensions may be symbolic expressions.
+This can lead to a number of errors. For example, the program:

 ```python
 four_ones = np.ones((4,))
@@ -597,6 +595,30 @@ The solution is to avoid `np.array`, `float`, or JAX arrays
 in operations whose results are used as shapes, e.g., instead of
 `np.arange(n) * x.shape[0]` write `[i * x.shape[0] for i in range(n)]`.
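+For example, a minimal sketch of this rewrite (a hypothetical `scale_shape`,
+assuming an argument `x: f32[b, 4]`):
+
+```python
+def scale_shape(x):
+  # Bad: np.array(x.shape) forces the symbolic dimension `b` into a
+  # concrete np.ndarray and raises an error during tracing:
+  #   return jnp.ones(np.array(x.shape) * 2, dtype=x.dtype)
+  # Good: keep the shape a list of dimension expressions.
+  return jnp.ones([d * 2 for d in x.shape], dtype=x.dtype)
+```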
+JAX assumes that dimension variables range over strictly positive integers.
+These assumptions are now checked against the shapes of the actual arguments
+when the lowered code is invoked.
+For example, given the `polymorphic_shapes="(b, b, 2*d)"`
+specification, we will generate code to check the following constraints when
+invoked with actual argument `arg`:
+
+  * `arg.shape[0] >= 1`
+  * `arg.shape[1] == arg.shape[0]`
+  * `arg.shape[2] % 2 == 0` and `arg.shape[2] // 2 >= 1`
+
+When using native serialization these are checked by the `tf.XlaCallModule`
+op (starting with serialization
+[version 7](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+jax_serialization_version&type=code)),
+and you will get `tf.errors.InvalidArgumentError` errors.
+You can disable this checking by including `DisabledSafetyCheck.shape_assertions()`
+in the `disabled_checks` parameter to `jax2tf.convert`, or by setting
+the environment variable
+`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=shape_assertions`.
+When using graph serialization these are checked using `tf.debugging.assert_equal`,
+which will also result in `tf.errors.InvalidArgumentError`.
+Note that due to limitations in TensorFlow, these errors are suppressed when using
+`jit_compile=True` and when running on TPU.
+
 ### Comparison of symbolic dimensions is partially supported

 Inside JAX there are a number of equality and inequality comparisons
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index a2ed5ecce..3d7fbf851 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -567,17 +567,6 @@ class GraphSerializationImpl(SerializationImpl):
           self.args_avals_flat, args_kwargs_tree=self.in_tree),
         self.args_flat_tf, self.args_avals_flat, self.name_stack)

-    # We invoke shape checking to give it a chance to raise shape errors that
-    # are evident statically. This should work in TF eager mode because all
-    # the shapes are known.
-    # TODO: handle non-static shape checking for graph serialization
-    acc_shape_check_messages: list[str] = []
-    _, _ = _interpret_fun_jax(
-        partial(shape_poly.compute_shape_check_from_arg_shapes,
-                self.args_avals_flat, args_kwargs_tree=self.in_tree,
-                acc_shape_check_messages=acc_shape_check_messages),
-        self.args_flat_tf, self.args_avals_flat, self.name_stack)
-
     _thread_local_state.shape_env = zip(dim_vars, dim_values)

     fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
@@ -3386,6 +3375,16 @@ def _dim_as_value_jax2tf(dim: shape_poly.DimSize):

 tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf

+def _shape_assertion_jax2tf(assert_what, *error_message_inputs,
+                            error_message: str):
+  tf.debugging.assert_equal(
+      assert_what, True,
+      message=error_message.format(*error_message_inputs))
+  return []
+
+tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf
+
 def _reduce_precision(x, *, exponent_bits, mantissa_bits):
   return tfxla.reduce_precision(x, exponent_bits=exponent_bits,
                                 mantissa_bits=mantissa_bits)
diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index 01dc0c1b0..f470f8dc6 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -82,6 +82,15 @@ class DisabledSafetyCheck:
     """
     return DisabledSafetyCheck(f"custom_call:{target_name}")

+  @classmethod
+  def shape_assertions(cls) -> "DisabledSafetyCheck":
+    """Allows invocations with shapes that do not meet the constraints.
+
+    Has effect on serialization (to suppress the generation of the assertions)
+    and also on deserialization (to suppress the checking of the assertions).
+    """
+    return DisabledSafetyCheck("shape_assertions")
+
   def is_custom_call(self) -> Optional[str]:
     """Returns the custom call target allowed by this directive."""
     m = re.match(r'custom_call:(.+)$', self._impl)
@@ -163,66 +172,6 @@ class Exported:
   def mlir_module(self) -> ir.Module:
     return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)

-  def shape_check_module(self) -> Optional[tuple[bytes, Sequence[str]]]:
-    """Generates a serialized shape checking module and the error messages.
-
-    Consider the exporting of a function with one array argument of type
-    "f32[w, 2 * h]", where "w" and "h" are two dimension variables. JAX tracing
-    assumes that `arg.shape[1]` is even, and that both `w` and `h` have
-    values >= 1. We must ensure that these assumptions hold when the function
-    is invoked.
-
-    If we `call_exported` for this module we perform these checks
-    statically (in `call_exported_abstract_eval`).
-
-    But if we wrap the exported module with XlaCallModule for execution outside
-    JAX, we must defer these checks to compile time, when the static shapes
-    are known.
-    We generate a separate MLIR shape checking module with a main
-    function that returns a triple of int32's, with a shape checking code and
-    two shape checking operands. The shape checking code is -1 if the checks
-    pass, or is an index into the returned shape_check_messages otherwise.
-    The returned shape checking operands are substituted for "%1" and "%2" in
-    the shape checking messages to obtain the error message. Here is a shape
-    checking module for our example:
-
-       func public main(arg: f32[?, ?]) {
-         # code will be the index of the last failed shape check.
-         code = op1 = op2 = -1  # Assume no errors initially
-         # Check that w is >= 1
-         arg_w = hlo.get_dimension_size(arg, 0)
-         code = hlo.select(arg_w >= 1, code, 0)  # error message 0
-         op1 = hlo.select(arg_w >= 1, op1, arg_w)
-         op2 = hlo.select(arg_w >= 1, op2, 1)
-         # Check that dim1 is even
-         dim1 = hlo.get_dimension_size(arg, 1)
-         code = hlo.select(dim1 % 2 == 0, code, 1)  # error message 1
-         op1 = hlo.select(dim1 % 2 == 0, op1, dim1 % 2)
-         op2 = hlo.select(dim1 % 2 == 0, op2, 0)
-         # Check that h >= 1
-         arg_h = hlo.floordiv(dim1, 2)
-         code = hlo.select(arg_h >= 1, code, 2)  # error message 2
-         op1 = hlo.select(arg_h >= 1, op1, arg_h)
-         op2 = hlo.select(arg_h >= 1, op2, 1)
-         return (code, op1, op2)
-
-    We also return the following shape checking messages to be used for
-    constructing error messages:
-
-      shape_check_messages = [
-        "Dimension variable 'w' must have integer value >= 1. Found %1",
-        "Dimension variable 'h' must have integer value >= 1. Found non-zero remainder %1",
-        "Dimension variable 'h' must have integer value >= 1. Found %1"]
-    """
-    shape_check_res = _make_shape_check_module(
-        self.in_avals, args_kwargs_tree=self.in_tree)
-    if shape_check_res is None:
-      return None
-    module, shape_check_messages = shape_check_res
-    module_serialized, xla_call_module_version = _serialize_module(module)
-    assert xla_call_module_version == self.xla_call_module_version
-    return module_serialized, shape_check_messages
-
   def __str__(self):
     # This is called to make a MLIR source location when we call an Exported, and we
     # do not want the entire serialized module to end up in locations.
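A minimal usage sketch of the new `DisabledSafetyCheck.shape_assertions()`
directive (mirroring the pattern used in `jax_export_test.py` further below;
`f` stands for any JAX function):

```python
import numpy as np
from jax.experimental.jax2tf import jax_export

# Export `f` with the shape-assertion safety check disabled, as the tests
# below do for older XlaCallModule versions (xc.mlir_api_version <= 51).
exp = jax_export.export(
    f, disabled_checks=(jax_export.DisabledSafetyCheck.shape_assertions(),)
)(jax_export.poly_spec((3, 4, 12), np.float32, "b, 4, 12"))
```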
@@ -359,6 +308,8 @@ def export(fun_jax: Callable, exported = jax_export.export(f_jax)(*args, **kwargs) """ fun_name = getattr(fun_jax, "__name__", "unknown") + version = config.jax_serialization_version + def do_export(*args_specs, **kwargs_specs) -> Exported: if not hasattr(fun_jax, "lower"): # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also @@ -373,28 +324,39 @@ def export(fun_jax: Callable, allow_non_replicated_sharding = True lowering_platform_str = lowering_platform or default_lowering_platform() - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_platform=lowering_platform_str) - lowering = lowered._lowering # type: ignore - _check_lowering(lowering) - mlir_module = lowering.stablehlo() - args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - # All arguments are kept if we have dimension variables. - assert len(module_kept_var_idx) == len(args_avals_flat) - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree - ) - mlir_module_serialized, xla_call_module_version = _serialize_module(mlir_module) + # Do not include shape assertions if the version is < 7. + enable_shape_assertions = ( + DisabledSafetyCheck.shape_assertions() not in disabled_checks and + version >= 7) # type: ignore + try: + prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions + shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions + lowered = wrapped_fun_jax.lower( + *args_specs, **kwargs_specs, + _experimental_lowering_platform=lowering_platform_str) + + lowering = lowered._lowering # type: ignore + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + # All arguments are kept if we have dimension variables. + assert len(module_kept_var_idx) == len(args_avals_flat) + mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree + ) + finally: + shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions + mlir_module_serialized = _serialize_module(mlir_module) # Figure out the result types and shapes if "global_out_avals" in lowering.compile_args: @@ -408,7 +370,7 @@ def export(fun_jax: Callable, # Log and then check the module. 
    if logging.vlog_is_on(3):
      mlir_module_text = mlir.module_to_string(mlir_module)
-      logmsg = (f"version={xla_call_module_version} "
+      logmsg = (f"version={version} "
                f"lowering_platform={lowering_platform_str} "
                f"disabled_checks={disabled_checks}")
      logging.info("Lowered JAX module: %s\n", logmsg)
@@ -432,14 +394,13 @@
        mlir_module_serialized=mlir_module_serialized,
        module_kept_var_idx=module_kept_var_idx,
        module_uses_dim_vars=shape_poly_state.uses_dim_vars,
-        xla_call_module_version=xla_call_module_version,
+        xla_call_module_version=version,  # type: ignore
        _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))

  return do_export

-def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
-  xla_call_module_version = config.jax_serialization_version
+def _serialize_module(module: ir.Module) -> bytes:
  mlir_str = mlir.module_to_bytecode(module)
  if hlo.get_api_version() < 4:
    target_version = hlo.get_earliest_forward_compatible_version()
@@ -463,7 +424,7 @@
    target_version = hlo.get_minimum_version()
  module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
      mlir_str, target_version)
-  return module_serialized, xla_call_module_version
+  return module_serialized


def _wrap_main_func(
@@ -481,9 +442,9 @@
    the `args_avals`, in sorted order.

  Consider the lowering of a function with one array argument of type "f32[w,
-  2 * h]",
-  where "w" and "h" are two dimension variables. The `module` will also
-  contain two dimension arguments, corresponding to "h" and "w" respectively:
+  2 * h]", where "w" and "h" are two dimension variables. The `module` will
+  also contain two dimension arguments, corresponding to "h" and "w"
+  respectively:

      func public main(arg_h: i32, arg_w: i32, arg: f32[?, ?]) {
        ...
@@ -495,6 +456,7 @@
        arg_w = hlo.get_dimension_size(arg, 0)
        dim1 = hlo.get_dimension_size(arg, 1)
        arg_h = hlo.floordiv(dim1, 2)
+        call _check_shape_assertions(arg)  # See below
        res = call _wrapped_jax_export_main(arg_h, arg_w, arg)
        return res
      }
@@ -503,6 +465,31 @@
  function by providing dummy values. This ensures that the main function's
  calling convention is as expected.

+  Note that the lowering contains a call to `_check_shape_assertions`.
+  JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
+  have values >= 1. We must check these constraints when we invoke the
+  module. We use a special custom call `@shape_assertion` that takes
+  a boolean first operand, a string `error_message` attribute that may contain
+  format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
+  scalar operands corresponding to the format specifiers.
+
+      func private _check_shape_assertions(arg: f32[?, ?]) {
+        # Check that w is >= 1
+        arg_w = hlo.get_dimension_size(arg, 0)
+        custom_call @shape_assertion(arg_w >= 1, arg_w,
+          error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
+        # Check that dim1 is even
+        dim1 = hlo.get_dimension_size(arg, 1)
+        custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
+          error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
+        # Check that h >= 1
+        arg_h = hlo.floordiv(dim1, 2)
+        custom_call @shape_assertion(arg_h >= 1, arg_h,
+          error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
+      }
+
+  If we `call_exported` with this module we perform these checks
+  statically (in `call_exported_abstract_eval`).
+ Args: module: the HLO module as obtained from lowering. May have a number of dimension arguments, followed by the kept array arguments. @@ -601,55 +588,6 @@ def _wrap_main_func( return wrapped_module - -def _make_shape_check_module( - args_avals_flat: Sequence[core.AbstractValue], - *, - args_kwargs_tree: tree_util.PyTreeDef -) -> Optional[tuple[ir.Module, Sequence[str]]]: - """Codegens the shape checking function. - - The shape checking function takes the array inputs and returns a triple. - See `Exported.shape_check_module` docstring. - - Returns the shape checking module and the tuple of shape checking messages. - """ - context = mlir.make_ir_context() - with context, ir.Location.unknown(context): - array_input_types = tuple(mlir.aval_to_ir_type(a) for a in args_avals_flat) - module = ir.Module.create(loc=ir.Location.unknown(context)) - - symbol_table = ir.SymbolTable(module.operation) - shape_check_code_type = mlir.aval_to_ir_type(core.ShapedArray((), np.int32)) - shape_check_output_types = (shape_check_code_type,) * 3 - shape_check_ftype = ir.FunctionType.get(array_input_types, - shape_check_output_types) - shape_check_func = func_dialect.FuncOp( - "main", shape_check_ftype, - ip=ir.InsertionPoint.at_block_begin(module.body)) - shape_check_func.attributes["sym_visibility"] = ir.StringAttr.get("public") - symbol_table.insert(shape_check_func) - entry_block = shape_check_func.add_entry_block() - with ir.InsertionPoint(entry_block): - module_context = mlir.ModuleContext( - "cpu", "cpu", sharding_impls.ShardingContext([]), - source_info_util.new_name_stack(), - [], itertools.count(1), [], module=module, context=context) - ctx = mlir.LoweringRuleContext(module_context=module_context, - primitive=None, avals_in=args_avals_flat, - avals_out=None, tokens_in=mlir.TokenSet(), - tokens_out=None) - acc_shape_check_messages: list[str] = [] - values = mlir.lower_fun( - functools.partial(shape_poly.compute_shape_check_from_arg_shapes, - args_avals_flat, args_kwargs_tree=args_kwargs_tree, - acc_shape_check_messages=acc_shape_check_messages), - multiple_results=True)(ctx, *shape_check_func.arguments) - func_dialect.ReturnOp(util.flatten(values)) - - return (module, acc_shape_check_messages) if acc_shape_check_messages else None - - def _check_lowering(lowering) -> None: if not isinstance(lowering, pxla.MeshComputation): raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") @@ -749,6 +687,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { # See https://github.com/openxla/stablehlo/issues/8. "stablehlo.dynamic_reduce_window", "stablehlo.dynamic_rng_bit_generator", + "shape_assertion", # Used by shape_poly to evaluate assertions } @@ -763,6 +702,7 @@ def _check_module(mod: ir.Module, *, disabled_checks: the safety checks that are disabled. 
""" sharding_attr = ir.StringAttr.get("Sharding", mod.context) + shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context) allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) for dc in disabled_checks: target = dc.is_custom_call() @@ -799,6 +739,8 @@ def _check_module(mod: ir.Module, *, disallowed_custom_call_ops.append(f"{op} at {op.location}") if call_target_name_attr == sharding_attr: check_sharding(op, op.location) + elif call_target_name_attr == shape_assertion_attr: + assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks) def walk_operations(op): check_op(op) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index 38dbb6a0f..94f569d6c 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -38,6 +38,7 @@ import itertools import io import math import operator as op +import threading import tokenize from typing import Any, Callable, Iterable, Optional, Sequence, Union @@ -50,6 +51,7 @@ from jax.interpreters import xla from jax._src import core from jax._src import dtypes +from jax._src import effects from jax._src.lax import lax from jax._src.interpreters import mlir from jax._src.numpy import lax_numpy @@ -79,6 +81,15 @@ for more details. # https://github.com/python/mypy/issues/5887 super().__init__(error_msg) # type: ignore +class _ShapePolyThreadLocalState(threading.local): + + def __init__(self): + # TODO(necula): this does not play well with some lowering caches, because + # this state is not part of the cache key. + self.enable_shape_assertions = True + +thread_local_state = _ShapePolyThreadLocalState() + class _DimAtom: """Represents an atom in a symbolic dimension expression. @@ -792,6 +803,62 @@ def _einsum_contract_path(*operands, **kwargs): lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path +# To implement shape-constraint checking we use a shape assertion primitive. +# shape_assertion_p.bind(assert_what: bool, *error_message_inputs, +# error_message="...{0}...{1}") +# where "{0}" refers to error_message_inputs[0], etc. +shape_assertion_p = core.Primitive("shape_assertion") +shape_assertion_p.multiple_results = True +shape_assertion_p.def_effectful_abstract_eval( + lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore + +def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext, + assert_what: mlir.ir.Value, + *error_message_inputs: mlir.ir.Value, + error_message: str): + op = mlir.custom_call( + "shape_assertion", + [], # No results + [assert_what, *error_message_inputs], + has_side_effect=True, + extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message)) + ) + return op.results + +mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule) + +class ShapeAssertionEffect(effects.Effect): + __str__ = lambda _: "ShapeAssertionEffect" + +shape_assertion_effect = ShapeAssertionEffect() + +effects.lowerable_effects.add_type(ShapeAssertionEffect) +effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect) +effects.remat_allowed_effects.add_type(ShapeAssertionEffect) +effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) + +def shape_assertion(assert_what: jax.Array, + *error_message_inputs: jax.Array, + error_message: str) -> None: + """Adds a shape assertion in the code. + + Args: + assert_what: a boolean asserted to be true. Must be computed based only + on dimension expressions, so that it can be evaluated after shape + refinement. 
+    error_message_inputs: integer expressions whose values can be referenced
+      in the `error_message`. Must be computed based only
+      on dimension expressions, so that they can be evaluated after shape
+      refinement.
+    error_message: an error message, possibly containing format specifiers
+      {0}, {1}, ..., referencing the values of the `error_message_inputs`.
+      The format specifiers are sometimes processed with Python's
+      `str.format` method, and sometimes with `llvm::formatv`.
+  """
+  if thread_local_state.enable_shape_assertions:
+    shape_assertion_p.bind(assert_what, *error_message_inputs,
+                           error_message=error_message)
+
 # A JAX primitive with no array arguments but with a dimension parameter
 # that is a DimExpr. The value of the primitive is the value of the dimension,
 # using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
@@ -1055,22 +1122,6 @@ def _evaluate_multiply(v1, v2):
      pass
  return v1 * v2

-def _is_known_constant(v) -> Optional[int]:
-  try:
-    return int(v)
-  except Exception:
-    # TODO(necula): added this so that in jax2tf, in Eager mode, we can tell
-    # that a tensor is a constant. We should move this dependency into some
-    # jax2tf-specific area.
-    if hasattr(v, "val"):
-      try:
-        vint = int(v.val)
-        if isinstance(vint, int):  # In TF, int(tf.Tensor) is tf.Tensor!
-          return vint
-      except Exception:
-        pass
-  return None
-
 # dimension_size(operand, dimension=i) get the operand.shape[i] as a
 # value of type shape_poly.dim_as_value_dtype().
 dimension_size_p = core.Primitive("dimension_size")
@@ -1180,8 +1231,8 @@ class ShapeConstraint:
    left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
                                                             self.right]]
    # Try to evaluate the constraint statically.
-    left_int, right_int = _is_known_constant(left), _is_known_constant(right)
-    if left_int is not None and right_int is not None:
+    if core.is_constant_shape((left, right)):
+      left_int, right_int = op.index(left), op.index(right)
      if self.comp == ShapeConstraint.Comparator.EQ:
        if not (left_int == right_int):
          raise ValueError(self.make_err_msg(left_int, right_int))
@@ -1225,34 +1276,21 @@ class ShapeConstraints:
    for constraint in self.constraints:
      constraint.check_statically(shapeenv)

-  def compute(self, shapeenv: DimVarEnv) -> tuple[jax.Array, jax.Array, jax.Array, Sequence[str]]:
-    """Computes the error code for the set of constraints.
+  def shape_assertions(self, shapeenv: DimVarEnv) -> None:
+    """Computes the shape assertions for the set of constraints.

-    The error code is -1 if all constraints are satisfied, or an index into
-    the returned error messages.
-    See Exported.shape_check_module docstring.
+    See _wrap_main_func docstring.
    """
    # We want to report the errors in the same order as `check_statically`.
-    # So, we process them in order, in case some fail statically, but we
-    # accumulate the errors in reverse order in the computation, because the
-    # code-generation strategy will report the last error that failed.
-    acc: list[tuple[jax.Array, jax.Array, jax.Array, str]] = []
+    # So, we process them in order, in case some fail statically, and we
+    # generate the shape assertions in the same order.
for constraint in self.constraints: check_res = constraint.compute(shapeenv) if check_res is not None: - acc.append((*check_res, constraint.make_err_msg("%1", "%2"))) # type: ignore - - shape_check_messages: list[str] = [] - shape_check_code: jax.Array = np.int32(-1) # type: ignore - shape_check_op1 = shape_check_op2 = shape_check_code - for (is_ok, op1, op2, msg) in reversed(acc): - shape_check_code = lax.select(is_ok, shape_check_code, - np.int32(len(shape_check_messages))) - shape_check_op1 = lax.select(is_ok, shape_check_op1, op1) - shape_check_op2 = lax.select(is_ok, shape_check_op2, op2) - shape_check_messages.append(msg) - return (shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages) - + is_ok, op1, op2 = check_res + shape_assertion( + is_ok, op1, op2, + error_message=constraint.make_err_msg("{0}", "{1}")) @dataclasses.dataclass class _DimEquation: @@ -1352,7 +1390,8 @@ def compute_dim_vars_from_arg_shapes( Like `solve_dim_vars` except that here we express the solution as JAX arrays that reference the `actual_args`. This function can be used to - generate the code for computing the dimension variables. + generate the code for computing the dimension variables. It also generates + the shape assertions. Returns: the values of the dimension variables, in the order determined by `all_dim_vars(args_avals)`. @@ -1366,39 +1405,9 @@ def compute_dim_vars_from_arg_shapes( dimension=dim_idx) for (vname, arg_idx, dim_idx) in synth_dim_vars} dim_values = [solution[var].evaluate(synthetic_env) for var in dim_vars] + shape_constraints.shape_assertions(synthetic_env) return tuple(dim_values) - -def compute_shape_check_from_arg_shapes( - args_avals: Sequence[core.AbstractValue], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef, - acc_shape_check_messages: list[str]) -> Sequence[jax.Array]: - """Computes the shape check code from the actual arguments. - - `acc_shape_check_messages` is an initially empty list where we append the - shape checking messages. Each of these correspond to a constraint. - See the `Exported.shape_check_module` docstring. - - Returns: a triple of JAX arrays: the shape check code and the values of the - two operands for the failed constraint. - - Raises ValueError if a constraint is invalidated statically. 
- """ - assert not acc_shape_check_messages - solution, shape_constraints, synth_dim_vars = solve_dim_vars( - tuple(args_avals), args_kwargs_tree=args_kwargs_tree) - - # Replace the synthetic vars with the dynamic shape of the actual arg - synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], - dimension=dim_idx) - for (vname, arg_idx, dim_idx) in synth_dim_vars} - shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages = ( - shape_constraints.compute(synthetic_env)) - acc_shape_check_messages.extend(shape_check_messages) - return (shape_check_code, shape_check_op1, shape_check_op2) - - def _solve_dim_equations( eqns: list[_DimEquation] ) -> tuple[DimVarEnv, ShapeConstraints]: diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 6d875be02..1a56b5925 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase): covered_targets = covered_targets.union({ "tpu_custom_call", # tested separately + # TODO(necula): add tests for shape_assertion + "shape_assertion", }) not_covered = targets_to_cover.difference(covered_targets) self.assertEmpty(not_covered) @@ -608,7 +610,9 @@ class CompatTest(bctu.CompatTestBase): data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17) self.run_one_test( func, data, - polymorphic_shapes=("b, ...",)) + polymorphic_shapes=("b, ...",), + # TODO(necula): now also includes shape_assertion + compare_with_current=False) def test_tpu_stablehlo_dynamic_reduce_window_variadic(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a @@ -629,7 +633,9 @@ class CompatTest(bctu.CompatTestBase): data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17) self.run_one_test( func, data, - polymorphic_shapes=("b, ...", "b, ...")) + polymorphic_shapes=("b, ...", "b, ..."), + # TODO(necula): now also includes shape_assertion + compare_with_current=False) def test_stablehlo_dynamic_rbg_bit_generator(self): # stablehlo.dynamic_rbg_bit_generator is used temporarily for a @@ -660,7 +666,9 @@ class CompatTest(bctu.CompatTestBase): try: jax.config.update("jax_default_prng_impl", "unsafe_rbg") - self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1")) + self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"), + # TODO(necula): now also includes shape_assertion + compare_with_current=False) finally: jax.config.update("jax_default_prng_impl", prev_default_prng_impl) diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/jax/experimental/jax2tf/tests/jax_export_test.py index 578efed1b..bdd877fbb 100644 --- a/jax/experimental/jax2tf/tests/jax_export_test.py +++ b/jax/experimental/jax2tf/tests/jax_export_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
import contextlib import re -from typing import Optional, Sequence import unittest from absl.testing import absltest @@ -28,136 +27,14 @@ from jax._src import core from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.interpreters import mlir -from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.lib.mlir.dialects import func as func_dialect -from jax._src.lib import xla_extension import numpy as np config.parse_flags_with_absl() -if xc.mlir_api_version >= 50: - def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array, - skip_shape_check: bool = False): - assert all(core.is_constant_shape(a.shape) for a in args) - return jax_export.call_exported(exp)(*args) -else: - def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array, - skip_shape_check: bool = False): - """Temporary runner for an Exported. - - Normally we would use jax_export.call_exported, but if the exported has - shape polymorphism and we are using jaxlib before 0.4.12 we use - Once we upgrade the jaxlib we can replace all uses of this function with - jax_export.call_exported. - """ - assert all(core.is_constant_shape(a.shape) for a in args) - if not exp.module_uses_dim_vars: - return jax_export.call_exported(exp)(*args) - else: - # We only get here in external tests, because internal ones use newest jaxlib - from jax.experimental.jax2tf import jax2tf # TODO: temporary - # call_exported does the shape checking, we must do it manually for - # XlaCallModule. - if not skip_shape_check: - shape_check = exp.shape_check_module() - if shape_check is not None: - err_msg = _run_shape_check_module(exp, shape_check[0], shape_check[1], args) - if err_msg is not None: - raise ValueError(err_msg) - - numpy_results = map(lambda res_tf: res_tf.numpy(), - jax2tf._run_exported_as_tf(args, exp)) - return exp.out_tree.unflatten(numpy_results) - - -def _run_shape_check_module(primal_exported: jax_export.Exported, - shape_check_module_serialized: bytes, - shape_check_messages: Sequence[str], - args: Sequence[jax.Array] - ) -> Optional[str]: - """Helper to run a shape checking module. - - We only need to do this in tests, because otherwise this will be done - implicitly when we call XlaCallModule. - - Returns: the error message, or None if no error. - """ - args = tuple(args) - # We cannot just make an Exported and run it, because call_exported will - # do the shape checks statically. 
So, we wrap the shape check module - # with static shape arguments - - static_in_avals = tuple(core.get_aval(a) for a in args) - context = mlir.make_ir_context() - with context, ir.Location.unknown(context): - wrapped_module = ir.Module.parse( - xla_extension.mlir.deserialize_portable_artifact( - shape_check_module_serialized)) - symbol_table = ir.SymbolTable(wrapped_module.operation) - orig_main = symbol_table["main"] - orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") - symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main") - orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value - # Use static shapes - new_main_input_types = [mlir.aval_to_ir_type(a) for a in static_in_avals] - orig_output_types = orig_main.type.results - new_main_ftype = ir.FunctionType.get( - new_main_input_types, orig_output_types - ) - new_main_op = func_dialect.FuncOp( - "main", - new_main_ftype, - ip=ir.InsertionPoint.at_block_begin(wrapped_module.body), - ) - new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public") - symbol_table.insert(new_main_op) - entry_block = new_main_op.add_entry_block() - with ir.InsertionPoint(entry_block): - orig_main_args: list[ir.Value] = [] - for new_arg, orig_arg_type in zip( - new_main_op.arguments, orig_main.type.inputs - ): - orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result) - call = func_dialect.CallOp( - orig_output_types, - ir.FlatSymbolRefAttr.get(orig_main_name), - orig_main_args, - ) - func_dialect.ReturnOp(call.results) - symbol_table.set_symbol_name(new_main_op, "main") - - wrapped_module_serialized, version = jax_export._serialize_module(wrapped_module) - # Make an Exported and then run it - out_avals = (core.ShapedArray((), dtype=np.int32),) * 3 - exp = jax_export.Exported( - fun_name=f"shape_check_{primal_exported.fun_name}", - in_tree=tree_util.tree_flatten((args, {}))[1], - in_avals=static_in_avals, - out_tree=tree_util.tree_flatten(out_avals)[1], - out_avals=out_avals, - in_shardings=None, - out_shardings=None, - lowering_platform=primal_exported.lowering_platform, - disabled_checks=(), - mlir_module_serialized=wrapped_module_serialized, - xla_call_module_version=version, - module_kept_var_idx=tuple(sorted(range(len(static_in_avals)))), - module_uses_dim_vars=True, - _get_vjp=lambda _: None) # type: ignore - - code, op1, op2 = _temp_call_exported(exp, *args, - skip_shape_check=True) - if code == -1: - return None - else: - return shape_check_messages[code].replace("%1", str(int(op1))).replace( - "%2", str(int(op2))) - - class JaxExportTest(jtu.JaxTestCase): def test_basic_export_only(self): @@ -396,7 +273,7 @@ class JaxExportTest(jtu.JaxTestCase): dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)), dict(poly_spec="3,4,12", arg_shape=(3, 4, 13), # The shape check module does not test constant dimensions - expect_error_run=re.escape( + expect_error=re.escape( r"Shape mismatch for args[0].shape[2] (expected same constant)")), dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)), dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)), @@ -418,49 +295,31 @@ class JaxExportTest(jtu.JaxTestCase): def test_poly_shape_checks( self, poly_spec="3,a,a+8", arg_shape=(3, 4, 12), arg_dtype=np.float32, - expect_error=None, # If given, applies for expect_error_run and expect_error_shape_check - expect_error_run=None, # Error from running the exported module - expect_error_shape_check=None): # Error from running the shape check module - if expect_error is not None: - self.assertIsNone(expect_error_run, None) - 
self.assertIsNone(expect_error_shape_check, None) - expect_error_run = expect_error_shape_check = expect_error + expect_error=None): # If given, error from running the exported module + def f(x): # x: f32[poly_spec] return jnp.reshape(x, (-1, x.shape[1])) - exp_f = jax_export.export(f)( + if xc.mlir_api_version <= 51: + disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),) + else: + disabled_checks = () + exp_f = jax_export.export(f, disabled_checks=disabled_checks)( jax_export.poly_spec((3, 4, 12), np.float32, poly_spec)) self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] with contextlib.ExitStack() as stack: - if expect_error_run is not None: - stack.push(self.assertRaisesRegex(Exception, expect_error_run)) + if expect_error is not None: + stack.push(self.assertRaisesRegex(Exception, expect_error)) - res = _temp_call_exported(exp_f, arg) + assert core.is_constant_shape(arg.shape) + res = jax_export.call_exported(exp_f)(arg) - if not expect_error_run: + if not expect_error: self.assertAllClose(res, f(arg)) - # Test the shape_check_module - shape_check = exp_f.shape_check_module() - # We have shape_check only if the exported has polymorphic inputs shapes. - if all(core.is_constant_shape(a.shape) for a in exp_f.in_avals): - self.assertIsNone(shape_check) - self.assertIsNone(expect_error_shape_check) - else: - self.assertIsNotNone(shape_check) - shape_check_module, shape_check_messages = shape_check - err_msg = _run_shape_check_module(exp_f, - shape_check_module, shape_check_messages, (arg,)) - - if expect_error_shape_check is None: - self.assertIsNone(err_msg) - else: - self.assertRegex(err_msg, expect_error_shape_check) - - # An inner function is exported with polymorphic shapes inner_poly_spec, and # is called from an outer function, which is exported with outer_poly_spec. 
@jtu.parameterized_filterable( @@ -536,12 +395,16 @@ class JaxExportTest(jtu.JaxTestCase): expect_error_outer_exp=None, expect_error_run=None): # Polymorphic export called with static or polymorphic shapes + if xc.mlir_api_version <= 51: + disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),) + else: + disabled_checks = () def inner(x): # x: inner_poly_spec return jnp.reshape(x, (-1, x.shape[1])) arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = jax_export.export(inner)( + inner_exp = jax_export.export(inner, disabled_checks=disabled_checks)( jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) self.assertEqual(inner_exp.module_uses_dim_vars, @@ -556,7 +419,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = jax_export.export(outer)( + outer_exp = jax_export.export(outer, disabled_checks=disabled_checks)( jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) if expect_error_outer_exp is not None: @@ -564,17 +427,12 @@ class JaxExportTest(jtu.JaxTestCase): self.assertEqual(outer_exp.module_uses_dim_vars, (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12")) - shape_check = outer_exp.shape_check_module() - if all(core.is_constant_shape(a.shape) for a in outer_exp.in_avals): - self.assertIsNone(shape_check) - else: - self.assertIsNotNone(shape_check) with contextlib.ExitStack() as stack: if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = _temp_call_exported(outer_exp, arg) + res = jax_export.call_exported(outer_exp)(arg) if expect_error_run is not None: return diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 40257f76f..f1eeafd38 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -87,7 +87,9 @@ class DimExprTest(tf_test_util.JaxToTfTestCase): self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) 
", tshape)) a, b = shape_poly._parse_spec("a, b", (2, 3)) - @parameterized.named_parameters( + + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_{dim_spec}", dim_spec=dim_spec, dim_poly=dim_poly) for dim_spec, dim_poly in [ @@ -100,34 +102,36 @@ class DimExprTest(tf_test_util.JaxToTfTestCase): ("a + -1", a - 1), ("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))), ("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2), - ]) + ]]) def test_parse_dim(self, dim_spec="-2 * a^2 * b + b^2", dim_poly=-2 * a * a * b + b * b): self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (None,))) self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (None,))) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_{shape_spec=}", shape_spec=shape_spec) for shape_spec in [ "2.5", "a + a a", "a ^ a", "a, a", "_", "...", "a ;", ")(", "2a", "a@", "'a'", "('a', ...)", "mod(a)", "floordiv(a, b, c)", "..., 3" - ]) + ]]) def test_parse_error(self, shape_spec="a + a a"): with self.assertRaisesRegex(ValueError, "syntax error in polymorphic shape"): shape_poly._parse_spec(shape_spec, (None,)) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_{shape_spec=}", shape_spec=shape_spec, arg_shape=arg_shape) for shape_spec, arg_shape in [ ("3", (4,)), ("b, 3", (None, 4)), - ]) + ]]) def test_parse_mismatch_error(self, shape_spec="3", arg_shape=(4,)): with self.assertRaisesRegex(ValueError, @@ -347,7 +351,8 @@ class DimExprTest(tf_test_util.JaxToTfTestCase): self.assertEqual(a * 2 // a, 2) self.assertIsInstance(a * 2 // a, int) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}", dividend=dividend, divisor=divisor, quotient=quotient, remainder=remainder) @@ -364,7 +369,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase): (3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"), (2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"), (3, a, "floordiv(3, a)", "mod(3, a)"), - ]) + ]]) def test_poly_divmod(self, *, dividend, quotient, divisor, remainder): if isinstance(quotient, str): d1, d2 = divmod(dividend, divisor) @@ -930,37 +935,26 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): re.escape("pytree structure error: different types")): conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",)) - # The following do not work yet with native serialization because - # XlaCallModule does not yet do shape checking. 
- if config.jax2tf_default_native_serialization: - return - - # TODO(necula): enable even for native serialization with self.assertRaisesRegex(ValueError, "Cannot solve for values of dimension variables {'b'}"): conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d") - # TODO(necula): enable even for native serialization - with self.assertRaisesRegex(ValueError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Dimension variable 'b' must have integer value >= 1"): conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...") - # TODO(necula): enable even for native serialization - with self.assertRaisesRegex(ValueError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Dimension variable 'b' must have integer value >= 1"): conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...") - # TODO(necula): enable even for native serialization - with self.assertRaisesRegex(ValueError, + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "Dimension variable 'b' must have integer value >= 1"): conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...") - # TODO(necula): enable even for native serialization with self.assertRaisesRegex( - ValueError, + tf.errors.InvalidArgumentError, "Found inconsistency 3 != 2 when solving.*"): conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)") - def test_pytree(self): """Arguments and polymorphic_shapes are pytrees.""" @@ -1250,7 +1244,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x) self.assertAllClose(f(x), res_tf) - @jtu.sample_product(with_function=[False, True]) + @jtu.parameterized_filterable( + kwargs=[dict(with_function=v) for v in [True, False]] + ) def test_grad_int(self, with_function=False): # https://github.com/google/jax/issues/7093 # Also issue #6975. @@ -1418,34 +1414,56 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1, polymorphic_shapes=["(a, b)"])(np.ones((4, 4))) - # Unsoundness: not checking that the dimension variable is 0 + # Checking that the dimension variable is >= 1 def f1_jax(x): # f32[b] # We have to use "x" return jnp.concatenate([x, jnp.array([0. if x.shape[0] == 0 else 1.], dtype=np.float32)]) x0 = np.array([], np.float32) - # JAX with static shapes sees that the x.shape[0] == 0 self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0)) + # In graph serialization eager mode we catch the error with self.assertRaisesRegex( - ValueError, + tf.errors.InvalidArgumentError, "Dimension variable 'b' must have integer value >= 1. 
Found 0"):
      jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
                     native_serialization=False)(x0)

-    # In native serialization, or if we trace to a TF graph, we miss this
-    res1_tf = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                             native_serialization=True)(x0)
-    self.assertEqual(jnp.array([1.], dtype=np.float32), res1_tf)
-
+    # In graph serialization graph mode we also catch it (except on TPU,
+    # where the behavior is as with jit_compile=True)
    f1_tf = tf.function(
        jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                       native_serialization=False)
+                       native_serialization=False),
+        autograph=False,
    ).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
+    if jtu.device_under_test() == "tpu":
+      self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
+    else:
+      with self.assertRaisesRegex(
+          tf.errors.InvalidArgumentError,
+          "Dimension variable .* must have integer value"):
+        _ = f1_tf(x0)
+
+    # In graph serialization with jit_compile=True we do not catch the error
+    # and we return the wrong result
+    f1_tf = tf.function(
+        jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
+                       native_serialization=False),
+        autograph=False,
+        jit_compile=True
+    )
    self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))

-    # Unsoundness: not checking that the actual dimensions denoted by the same
+    # We also catch the error with native serialization
+    with self.assertRaisesRegex(
+        tf.errors.InvalidArgumentError,
+        "Dimension variable 'b' must have integer value >= 1. Found 0"):
+      _ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
+                         native_serialization=True)(x0)
+
+    # Checking that the actual dimensions denoted by the same
    # dimension variables have equal sizes.
    def f2_jax(x):  # f32[b, b]
      # We have to use "x"
@@ -1455,25 +1473,46 @@
    # JAX with static shapes sees that x.shape[0] != x.shape[1]
    self.assertEqual(jnp.sum(x45), f2_jax(x45))

-    # jax2tf catches the broken assumption b >= 1 if the converted function is executed
-    # eagerly.
+    # In graph serialization eager mode we catch the inconsistent dimension
+    # sizes
    with self.assertRaisesRegex(
-        ValueError,
+        tf.errors.InvalidArgumentError,
        r"Found inconsistency 5 != 4 when solving b == args\[0\].shape\[1\]"):
      jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
                     native_serialization=False)(x45)

-    # In native serialization, or if we trace to a TF graph, we miss this
-    res2_tf = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                             native_serialization=True)(x45)
-    self.assertEqual(1. + jnp.sum(x45), res2_tf)
+    # In graph serialization graph mode we also catch it (except on TPU, where
+    # the behavior is as with jit_compile=True)
    f2_tf = tf.function(
        jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                       native_serialization=False)
+                       native_serialization=False),
+        autograph=False,
    ).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
+    if jtu.device_under_test() == "tpu":
+      self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
+    else:
+      with self.assertRaisesRegex(
+          tf.errors.InvalidArgumentError,
+          r"Found inconsistency"):
+        _ = f2_tf(x45)
+
+    # In graph serialization with jit_compile=True we do not catch the error
+    # and we return the wrong result
+    f2_tf = tf.function(
+        jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
+                       native_serialization=False),
+        autograph=False,
+        jit_compile=True
+    )
    self.assertEqual(1.
+ jnp.sum(x45), f2_tf(x45)) + # We also catch the error with native serialization + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + "Found inconsistency 5 != 4"): + _ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"], + native_serialization=True)(x45) + x = np.ones((5,), dtype=np.float32) with self.assertRaisesRegex( ValueError, diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index e1dbe6185..650f33743 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -27,7 +27,6 @@ from typing import Any, Sequence import unittest from absl.testing import absltest -from absl.testing import parameterized import jax from jax._src import test_util as jtu @@ -180,12 +179,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): else: assert False - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}", in_shardings=in_shardings, out_shardings=out_shardings) for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") - ) + ]) @jtu.with_mesh([("x", 2)]) def test_pjit_basic(self, in_shardings="P", out_shardings="P"): # Ensure that we can distinguish the inputs and outputs by shape @@ -316,14 +316,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): res_tf = f_tf(x) self.assertAllClose(res_tf, res_jax) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}", nested_pjit=nested_pjit, constraint=constraint, poly=poly) # We add a constraint either with a nested pjit or with a sharding_constraint for nested_pjit in (True, False) for constraint in (None, "P") for poly in (None, "2*b1,_", "_,b2", "2*b1,b2") - ) + ]) @jtu.with_mesh([("x", 2)]) def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"): constraint_sharding = P("x", None) if constraint == "P" else None @@ -363,12 +364,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): (r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated) ]) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}", in_shardings=in_shardings, out_shardings=out_shardings) for in_shardings in ("missing", None, "P") for out_shardings in ("missing", None, "P") - ) + ]) @jtu.with_mesh([("x", 2)]) def test_grad_pjit(self, in_shardings="P", out_shardings=None): def f_jax(x): # x: f32[10,20] -> f32[20,10] @@ -414,7 +416,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated), ]) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}", kind=kind, in_shardings=in_shardings, out_shardings=out_shardings) for kind in ("pjit", "jit", "sharding_constraint") @@ -425,7 +428,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): for out_shardings in ( ("unspecified",) if kind in ["sharding_constraint", "jit"] else ("unspecified", "none", "P")) - ) + ]) def test_pjit_error_inner_sharding(self, kind="pjit", in_shardings="P", out_shardings="none"): # Check that we raise an error if there is no top-level pjit but we convert @@ -462,11 +465,12 @@ class 
ShardingTest(tf_test_util.JaxToTfTestCase): with Mesh(self.devices, axis_names=("x",)): f_tf(x) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_func={func}", func=func) for func in ("pjit_sharded", "pjit_replicated", "nested_pjit_sharded", "nested_pjit_replicated") - ) + ]) def test_pjit_eager_error(self, func="pjit_sharded"): if config.jax2tf_default_native_serialization: raise unittest.SkipTest("There is no error in eager mode for native serialization") @@ -693,10 +697,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): res_tf = f_tf(a) self.assertAllClose(res_tf, expected) - @parameterized.named_parameters( + @jtu.parameterized_filterable( + kwargs=[ dict(testcase_name=f"_poly={poly}", poly=poly) for poly in (None, "2*b1,_", "_,b2", "2*b1,b2") - ) + ]) def test_shmap_collective_permute(self, poly=None): if jtu.device_under_test() == "cpu": raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
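For reference, a minimal sketch of the `jtu.parameterized_filterable` pattern
these tests now use (see the `jax/_src/test_util.py` hunk at the top; the test
class and body here are hypothetical):

```python
from jax._src import test_util as jtu

class ExampleTest(jtu.JaxTestCase):

  @jtu.parameterized_filterable(
      # testcase_name is now optional; when absent it is synthesized by
      # joining the sorted key=value pairs of each kwargs dict.
      kwargs=[dict(poly=poly) for poly in (None, "2*b1,_", "_,b2")])
  def test_example(self, poly=None):
    self.assertIn(poly, (None, "2*b1,_", "_,b2"))
```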