Mirror of https://github.com/ROCm/jax.git
[shape_poly] Reimplement shape constraint checking using shape assertions.
Most of this functionality targets the JAX native serialization case. It relies on newly added functionality in xla_extension.refine_polymorphic_shapes that handles the `@shape_assertion` custom call. As a beneficial side effect, we now also get shape constraint checking for jax2tf graph serialization when the resulting function is executed in graph mode.
This commit is contained in:
parent f97dca79a2
commit e643f98558
@ -68,6 +68,9 @@ Remember to align the itemized text with the first line of an item within a list
* JAX now supports a configuration flag --jax_serialization_version
  and a JAX_SERIALIZATION_VERSION environment variable to control the
  serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
  certain shape constraints, if the serialization version is at least 7.
  See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
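To illustrate the first changelog entry above, here is a minimal sketch of selecting the serialization version; it assumes the flag and environment variable behave as described in that entry.

```python
import os

# Assumption: the environment variable is read when JAX initializes its
# configuration, so set it before importing jax.
os.environ["JAX_SERIALIZATION_VERSION"] = "7"

import jax

# Alternatively, update the configuration flag directly; the flag name is the
# one mentioned in the changelog entry above.
jax.config.update("jax_serialization_version", 7)
```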

## jaxlib 0.4.14

@ -2044,7 +2044,7 @@ def custom_call(
    out_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    *,
    backend_config: Union[str, dict] = "",
    backend_config: Union[str, dict[str, ir.Attribute]] = "",
    has_side_effect: bool = False,
    result_shapes: Optional[Sequence[ir.Value]] = None,
    called_computations: Sequence[str] = (),
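The signature change above widens `backend_config` to also accept a dictionary of MLIR attributes. A hedged sketch of what that enables; the target name and attribute are made up, and this would only run inside an active lowering context:

```python
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir

def emit_example_custom_call(operand: ir.Value) -> ir.Value:
  # Assumption: called from a lowering rule where an MLIR context/module is
  # active; "my_example_target" is a hypothetical custom call target.
  op = mlir.custom_call(
      "my_example_target",
      [operand.type],                           # out_types
      [operand],                                # operands
      backend_config=dict(                      # dict[str, ir.Attribute] now accepted
          example_flag=ir.BoolAttr.get(True)),
      has_side_effect=False,
  )
  return op.results[0]
```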
@ -1230,18 +1230,21 @@ def parameterized_filterable(*,
  Args:
    kwargs: Each entry is a set of kwargs to be passed to the test function.
    testcase_name: Optionally, a function to construct the testcase_name from
      one kwargs dict. If not given then the kwarg must contain `testcase_name`.
      one kwargs dict. If not given then each kwargs dict may contain a
      `testcase_name`; if it does not, the test case name is constructed as
      `str(kwarg)`.
    one_containing: If given, then leave the test name unchanged, and use
      only one `kwargs` whose `testcase_name` includes `one_containing`.
  """
  # Ensure that all kwargs contain a testcase_name
  kwargs_with_testcase_name: Sequence[dict[str, Any]]
  if testcase_name is not None:
    kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
    kwargs_with_testcase_name = [dict(testcase_name=str(testcase_name(kw)), **kw)
                                 for kw in kwargs]
  else:
    for kw in kwargs:
      assert "testcase_name" in kw
      if "testcase_name" not in kw:
        kw["testcase_name"] = "_".join(f"{k}={str(kw[k])}"
                                       for k in sorted(kw.keys()))
    kwargs_with_testcase_name = kwargs
  if one_containing is not None:
    filtered = tuple(kw for kw in kwargs_with_testcase_name
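A brief usage sketch of this decorator; the test class and kwargs are made up for illustration, but the call shape follows the test-file updates later in this diff:

```python
import numpy as np
from jax._src import test_util as jtu

class ExampleTest(jtu.JaxTestCase):

  @jtu.parameterized_filterable(
      # One test per kwargs dict; the second entry supplies an explicit
      # testcase_name, the first gets a name derived from its kwargs.
      kwargs=[dict(size=4), dict(size=8, testcase_name="large")])
  def test_ones_shape(self, *, size):
    self.assertEqual(np.ones((size,)).shape, (size,))
```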
@ -547,11 +547,9 @@ cannot be used anymore as dimension parameters and will raise a JAX error.

### Errors in presence of shape polymorphism

If you write your program assuming that all shapes are tuples of integers,
and then try to trace it with shape polymorphism you can run into a number
of errors.

The program:
Most JAX code assumes that the shapes of JAX arrays are tuples of integers,
but with shape polymorphism some dimensions may be symbolic expressions.
This can lead to a number of errors. For example, the program:

```python
four_ones = np.ones((4,))
@ -597,6 +595,30 @@ The solution is to avoid `np.array`, `float`, or JAX arrays in operations whose
results are used as shapes, e.g., instead of `np.arange(n) * x.shape[0]` write
`[i * x.shape[0] for i in range(n)]`.

JAX assumes that dimension variables range over strictly positive integers.
These assumptions are now checked against the shapes of the actual arguments
when the lowered code is invoked.
For example, given the `polymorphic_shapes="(b, b, 2*d)"`
specification, we will generate code to check the following constraints when
invoked with an actual argument `arg`:

* `arg.shape[0] >= 1`
* `arg.shape[1] == arg.shape[0]`
* `arg.shape[2] % 2 == 0` and `arg.shape[2] // 2 >= 1`

When using native serialization these are checked by the `tf.XlaCallModule`
op (starting with serialization
[version 7](https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Aconfig.py+jax_serialization_version&type=code)),
and you will get `tf.errors.InvalidArgument` errors.
You can disable this checking by including `DisabledSafetyCheck.shape_assertions()`
in the `disabled_checks` parameter to `jax2tf.convert`, or by setting
the environment variable
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=shape_assertions`.
When using graph serialization these are checked using `tf.debugging.assert_equal`,
which will also result in `tf.errors.InvalidArgument`.
Note that due to limitations in TensorFlow, these errors are suppressed when using
`jit_compile=True` and when running on TPU.
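To make the checks described above concrete, here is a minimal sketch assuming the jax2tf API at the time of this change; `f` is a made-up function, and the failing call is only indicated in a comment:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):  # expected shape: f32[b, b, 2*d]
  return jnp.sin(x)

f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, b, 2*d)"])

# Satisfies all three constraints above: b = 3, d = 4.
print(f_tf(np.ones((3, 3, 8), dtype=np.float32)).shape)

# Violates arg.shape[1] == arg.shape[0]; with serialization version >= 7 this
# is expected to raise tf.errors.InvalidArgumentError when executed:
# f_tf(np.ones((3, 4, 8), dtype=np.float32))
```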

### Comparison of symbolic dimensions is partially supported

Inside JAX there are a number of equality and inequality comparisons
@ -567,17 +567,6 @@ class GraphSerializationImpl(SerializationImpl):
|
||||
self.args_avals_flat, args_kwargs_tree=self.in_tree),
|
||||
self.args_flat_tf, self.args_avals_flat, self.name_stack)
|
||||
|
||||
# We invoke shape checking to give it a chance to raise shape errors that
|
||||
# are evident statically. This should work in TF eager mode because all
|
||||
# the shapes are known.
|
||||
# TODO: handle non-static shape checking for graph serialization
|
||||
acc_shape_check_messages: list[str] = []
|
||||
_, _ = _interpret_fun_jax(
|
||||
partial(shape_poly.compute_shape_check_from_arg_shapes,
|
||||
self.args_avals_flat, args_kwargs_tree=self.in_tree,
|
||||
acc_shape_check_messages=acc_shape_check_messages),
|
||||
self.args_flat_tf, self.args_avals_flat, self.name_stack)
|
||||
|
||||
_thread_local_state.shape_env = zip(dim_vars, dim_values)
|
||||
|
||||
fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)
|
||||
@ -3386,6 +3375,16 @@ def _dim_as_value_jax2tf(dim: shape_poly.DimSize):

tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf

def _shape_assertion_jax2tf(assert_what, *error_message_inputs,
                            error_message: str):

  tf.debugging.assert_equal(
      assert_what, True,
      message=error_message.format(*error_message_inputs))
  return []

tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf

def _reduce_precision(x, *, exponent_bits, mantissa_bits):
  return tfxla.reduce_precision(x, exponent_bits=exponent_bits,
                                mantissa_bits=mantissa_bits)
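As a small aside, the `{0}`-style placeholders in `error_message` follow Python's `str.format` convention, which is exactly what the TF implementation above relies on:

```python
error_message = "Dimension variable 'b' must have integer value >= 1. Found {0}"
print(error_message.format(0))
# -> Dimension variable 'b' must have integer value >= 1. Found 0
```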
@ -82,6 +82,15 @@ class DisabledSafetyCheck:
    """
    return DisabledSafetyCheck(f"custom_call:{target_name}")

  @classmethod
  def shape_assertions(cls) -> "DisabledSafetyCheck":
    """Allows invocations with shapes that do not meet the constraints.

    Has effect on serialization (to suppress the generation of the assertions)
    and also on deserialization (to suppress the checking of the assertions).
    """
    return DisabledSafetyCheck("shape_assertions")

  def is_custom_call(self) -> Optional[str]:
    """Returns the custom call target allowed by this directive."""
    m = re.match(r'custom_call:(.+)$', self._impl)
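A hedged sketch of passing the new `shape_assertions` directive at export time; the API names follow the experimental `jax_export` usage in the tests updated later in this diff, and the function `f` is made up:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.jax2tf import jax_export

def f(x):  # x: f32[b, 4]
  return jnp.reshape(x, (-1,))

# Export without emitting shape assertions: the serialized module will not
# check at run time that the actual arguments satisfy b >= 1, etc.
exp = jax_export.export(
    f, disabled_checks=(jax_export.DisabledSafetyCheck.shape_assertions(),)
)(jax_export.poly_spec((3, 4), np.float32, "b, 4"))
```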
@ -163,66 +172,6 @@ class Exported:
|
||||
def mlir_module(self) -> ir.Module:
|
||||
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
|
||||
|
||||
def shape_check_module(self) -> Optional[tuple[bytes, Sequence[str]]]:
|
||||
"""Generates a serialized shape checking module and the error messages.
|
||||
|
||||
Consider the exporting of a function with one array argument of type
|
||||
"f32[w, 2 * h]", where "w" and "h" are two dimension variables. JAX tracing
|
||||
assumes that `arg.shape[1]` is even, and that both `w` and `h` have
|
||||
values >= 1. We must ensure that these assumptions hold when the function
|
||||
is invoked.
|
||||
|
||||
If we `call_exported` for this module we perform these checks
|
||||
statically (in `call_exported_abstract_eval`).
|
||||
|
||||
But if we wrap the exported module with XlaCallModule for execution outside
|
||||
JAX, we must defer these checks to compile time, when the static shapes
|
||||
are known.
|
||||
We generate a separate MLIR shape checking module with a main
|
||||
function that returns a triple of int32's, with a shape checking code and
|
||||
two shape checking operands. The shape checking code is -1 if the checks
|
||||
pass, or is an index into the returned shape_check_messages otherwise.
|
||||
The returned shape checking operands are substituted for "%1" and "%2" in
|
||||
the shape checking messages to obtain the error message. Here is a shape
|
||||
checking module for our example:
|
||||
|
||||
func public main(arg: f32[?, ?]) {
|
||||
# code will be the index of the last failed shape check.
|
||||
code = op1 = op2 = -1 # Assume no errors initially
|
||||
# Check that w is >= 1
|
||||
arg_w = hlo.get_dimension_size(arg, 0)
|
||||
code = hlo.select(arg_w >= 1, code, 0) # error message 0
|
||||
op1 = hlo.select(arg_w >= 1, op1, arg_w)
|
||||
op2 = hlo.select(arg_w >= 1, op2, 1)
|
||||
# Check that dim1 is even
|
||||
dim1 = hlo.get_dimension_size(arg, 1)
|
||||
code = hlo.select(dim1 % 2 == 0, code, 1) # error message 1
|
||||
op1 = hlo.select(dim1 % 2 == 0, op1, dim1 % 2)
|
||||
op2 = hlo.select(dim1 % 2 == 0, op2, 0)
|
||||
# Check that h >= 1
|
||||
arg_h = hlo.floordiv(dim1, 2)
|
||||
code = hlo.select(arg_h >= 1, code, 2) # error message 2
|
||||
op1 = hlo.select(arg_h >= 1, op1, arg_h)
|
||||
op2 = hlo.select(arg_h >= 1, op2, 1)
|
||||
return (code, op1, op2)
|
||||
|
||||
We also return the following shape checking messages to be used for
|
||||
constructing error messages:
|
||||
|
||||
shape_check_messages = [
|
||||
"Dimension variable 'w' must have integer value >= 1. Found %1",
|
||||
"Dimension variable 'h' must have integer value >= 1. Found non-zero remainder %1",
|
||||
"Dimension variable 'h' must have integer value >= 1. Found %1"]
|
||||
"""
|
||||
shape_check_res = _make_shape_check_module(
|
||||
self.in_avals, args_kwargs_tree=self.in_tree)
|
||||
if shape_check_res is None:
|
||||
return None
|
||||
module, shape_check_messages = shape_check_res
|
||||
module_serialized, xla_call_module_version = _serialize_module(module)
|
||||
assert xla_call_module_version == self.xla_call_module_version
|
||||
return module_serialized, shape_check_messages
|
||||
|
||||
def __str__(self):
|
||||
# This is called to make a MLIR source location when we call an Exported, and we
|
||||
# do not want the entire serialized module to end up in locations.
|
||||
@ -359,6 +308,8 @@ def export(fun_jax: Callable,
|
||||
exported = jax_export.export(f_jax)(*args, **kwargs)
|
||||
"""
|
||||
fun_name = getattr(fun_jax, "__name__", "unknown")
|
||||
version = config.jax_serialization_version
|
||||
|
||||
def do_export(*args_specs, **kwargs_specs) -> Exported:
|
||||
if not hasattr(fun_jax, "lower"):
|
||||
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
|
||||
@ -373,28 +324,39 @@ def export(fun_jax: Callable,
|
||||
allow_non_replicated_sharding = True
|
||||
|
||||
lowering_platform_str = lowering_platform or default_lowering_platform()
|
||||
lowered = wrapped_fun_jax.lower(
|
||||
*args_specs, **kwargs_specs,
|
||||
_experimental_lowering_platform=lowering_platform_str)
|
||||
lowering = lowered._lowering # type: ignore
|
||||
_check_lowering(lowering)
|
||||
mlir_module = lowering.stablehlo()
|
||||
|
||||
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
||||
if "kept_var_idx" in lowering.compile_args:
|
||||
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
|
||||
else:
|
||||
# For pmap
|
||||
module_kept_var_idx = tuple(range(len(args_avals_flat)))
|
||||
shape_poly_state = lowering.compile_args["shape_poly_state"]
|
||||
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
|
||||
or lowering.compile_args.get("ordered_effects", [])):
|
||||
# All arguments are kept if we have dimension variables.
|
||||
assert len(module_kept_var_idx) == len(args_avals_flat)
|
||||
mlir_module = _wrap_main_func(
|
||||
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
|
||||
)
|
||||
mlir_module_serialized, xla_call_module_version = _serialize_module(mlir_module)
|
||||
# Do not include shape assertions if the version is < 7.
|
||||
enable_shape_assertions = (
|
||||
DisabledSafetyCheck.shape_assertions() not in disabled_checks and
|
||||
version >= 7) # type: ignore
|
||||
try:
|
||||
prev_enable_shape_assertions = shape_poly.thread_local_state.enable_shape_assertions
|
||||
shape_poly.thread_local_state.enable_shape_assertions = enable_shape_assertions
|
||||
lowered = wrapped_fun_jax.lower(
|
||||
*args_specs, **kwargs_specs,
|
||||
_experimental_lowering_platform=lowering_platform_str)
|
||||
|
||||
lowering = lowered._lowering # type: ignore
|
||||
_check_lowering(lowering)
|
||||
mlir_module = lowering.stablehlo()
|
||||
|
||||
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
||||
if "kept_var_idx" in lowering.compile_args:
|
||||
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
|
||||
else:
|
||||
# For pmap
|
||||
module_kept_var_idx = tuple(range(len(args_avals_flat)))
|
||||
shape_poly_state = lowering.compile_args["shape_poly_state"]
|
||||
if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat)
|
||||
or lowering.compile_args.get("ordered_effects", [])):
|
||||
# All arguments are kept if we have dimension variables.
|
||||
assert len(module_kept_var_idx) == len(args_avals_flat)
|
||||
mlir_module = _wrap_main_func(
|
||||
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
|
||||
)
|
||||
finally:
|
||||
shape_poly.thread_local_state.enable_shape_assertions = prev_enable_shape_assertions
|
||||
mlir_module_serialized = _serialize_module(mlir_module)
|
||||
|
||||
# Figure out the result types and shapes
|
||||
if "global_out_avals" in lowering.compile_args:
|
||||
@ -408,7 +370,7 @@ def export(fun_jax: Callable,
|
||||
# Log and then check the module.
|
||||
if logging.vlog_is_on(3):
|
||||
mlir_module_text = mlir.module_to_string(mlir_module)
|
||||
logmsg = (f"version={xla_call_module_version} "
|
||||
logmsg = (f"version={version} "
|
||||
f"lowering_platform={lowering_platform_str} "
|
||||
f"disabled_checks={disabled_checks}")
|
||||
logging.info("Lowered JAX module: %s\n", logmsg)
|
||||
@ -432,14 +394,13 @@ def export(fun_jax: Callable,
|
||||
mlir_module_serialized=mlir_module_serialized,
|
||||
module_kept_var_idx=module_kept_var_idx,
|
||||
module_uses_dim_vars=shape_poly_state.uses_dim_vars,
|
||||
xla_call_module_version=xla_call_module_version,
|
||||
xla_call_module_version=version, # type: ignore
|
||||
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))
|
||||
|
||||
return do_export
|
||||
|
||||
|
||||
def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
|
||||
xla_call_module_version = config.jax_serialization_version
|
||||
def _serialize_module(module: ir.Module) -> bytes:
|
||||
mlir_str = mlir.module_to_bytecode(module)
|
||||
if hlo.get_api_version() < 4:
|
||||
target_version = hlo.get_earliest_forward_compatible_version()
|
||||
@ -463,7 +424,7 @@ def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
|
||||
target_version = hlo.get_minimum_version()
|
||||
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
return module_serialized, xla_call_module_version
|
||||
return module_serialized
|
||||
|
||||
|
||||
def _wrap_main_func(
|
||||
@ -481,9 +442,9 @@ def _wrap_main_func(
  the `args_avals`, in sorted order.

  Consider the lowering of a function with one array argument of type "f32[w,
  2 * h]",
  where "w" and "h" are two dimension variables. The `module` will also
  contain two dimension arguments, corresponding to "h" and "w" respectively:
  2 * h]", where "w" and "h" are two dimension variables. The `module` will
  also contain two dimension arguments, corresponding to "h" and "w"
  respectively:

      func public main(arg_h: i32, arg_w: i32, arg: f32[?, ?]) {
        ...
@ -495,6 +456,7 @@ def _wrap_main_func(
        arg_w = hlo.get_dimension_size(arg, 0)
        dim1 = hlo.get_dimension_size(arg, 1)
        arg_h = hlo.floordiv(dim1, 2)
        call _check_shape_assertions(arg)  # See below
        res = call _wrapped_jax_export_main(arg_h, arg_w, arg)
        return res
      }
@ -503,6 +465,31 @@ def _wrap_main_func(
  function by providing dummy values. This ensures that the main function's
  calling convention is as expected.

  Note that the lowering contains a call to `_check_shape_assertions`.
  JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
  have values >= 1. We must check these constraints when we invoke the
  module. We use a special custom call `@shape_assertion` that takes
  a boolean first operand, a string `error_message` attribute that may contain
  format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
  scalar operands corresponding to the format specifiers.

      func private _check_shape_assertions(arg: f32[?, ?]) {
        # Check that w is >= 1
        arg_w = hlo.get_dimension_size(arg, 0)
        custom_call @shape_assertion(arg_w >= 1, arg_w,
          error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
        # Check that dim1 is even
        dim1 = hlo.get_dimension_size(arg, 1)
        custom_call @shape_assertion(dim1 % 2 == 0, dim1,
          error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
        # Check that h >= 1
        arg_h = hlo.floordiv(dim1, 2)
        custom_call @shape_assertion(arg_h >= 1, arg_h,
          error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")

  If we `call_exported` with this module we perform these checks
  statically (in `call_exported_abstract_eval`).

  Args:
    module: the HLO module as obtained from lowering. May have a number of
      dimension arguments, followed by the kept array arguments.
@ -601,55 +588,6 @@ def _wrap_main_func(
|
||||
|
||||
return wrapped_module
|
||||
|
||||
|
||||
def _make_shape_check_module(
|
||||
args_avals_flat: Sequence[core.AbstractValue],
|
||||
*,
|
||||
args_kwargs_tree: tree_util.PyTreeDef
|
||||
) -> Optional[tuple[ir.Module, Sequence[str]]]:
|
||||
"""Codegens the shape checking function.
|
||||
|
||||
The shape checking function takes the array inputs and returns a triple.
|
||||
See `Exported.shape_check_module` docstring.
|
||||
|
||||
Returns the shape checking module and the tuple of shape checking messages.
|
||||
"""
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
array_input_types = tuple(mlir.aval_to_ir_type(a) for a in args_avals_flat)
|
||||
module = ir.Module.create(loc=ir.Location.unknown(context))
|
||||
|
||||
symbol_table = ir.SymbolTable(module.operation)
|
||||
shape_check_code_type = mlir.aval_to_ir_type(core.ShapedArray((), np.int32))
|
||||
shape_check_output_types = (shape_check_code_type,) * 3
|
||||
shape_check_ftype = ir.FunctionType.get(array_input_types,
|
||||
shape_check_output_types)
|
||||
shape_check_func = func_dialect.FuncOp(
|
||||
"main", shape_check_ftype,
|
||||
ip=ir.InsertionPoint.at_block_begin(module.body))
|
||||
shape_check_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
symbol_table.insert(shape_check_func)
|
||||
entry_block = shape_check_func.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
module_context = mlir.ModuleContext(
|
||||
"cpu", "cpu", sharding_impls.ShardingContext([]),
|
||||
source_info_util.new_name_stack(),
|
||||
[], itertools.count(1), [], module=module, context=context)
|
||||
ctx = mlir.LoweringRuleContext(module_context=module_context,
|
||||
primitive=None, avals_in=args_avals_flat,
|
||||
avals_out=None, tokens_in=mlir.TokenSet(),
|
||||
tokens_out=None)
|
||||
acc_shape_check_messages: list[str] = []
|
||||
values = mlir.lower_fun(
|
||||
functools.partial(shape_poly.compute_shape_check_from_arg_shapes,
|
||||
args_avals_flat, args_kwargs_tree=args_kwargs_tree,
|
||||
acc_shape_check_messages=acc_shape_check_messages),
|
||||
multiple_results=True)(ctx, *shape_check_func.arguments)
|
||||
func_dialect.ReturnOp(util.flatten(values))
|
||||
|
||||
return (module, acc_shape_check_messages) if acc_shape_check_messages else None
|
||||
|
||||
|
||||
def _check_lowering(lowering) -> None:
|
||||
if not isinstance(lowering, pxla.MeshComputation):
|
||||
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
|
||||
@ -749,6 +687,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
# See https://github.com/openxla/stablehlo/issues/8.
|
||||
"stablehlo.dynamic_reduce_window",
|
||||
"stablehlo.dynamic_rng_bit_generator",
|
||||
"shape_assertion", # Used by shape_poly to evaluate assertions
|
||||
}
|
||||
|
||||
|
||||
@ -763,6 +702,7 @@ def _check_module(mod: ir.Module, *,
|
||||
disabled_checks: the safety checks that are disabled.
|
||||
"""
|
||||
sharding_attr = ir.StringAttr.get("Sharding", mod.context)
|
||||
shape_assertion_attr = ir.StringAttr.get("shape_assertion", mod.context)
|
||||
allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
for dc in disabled_checks:
|
||||
target = dc.is_custom_call()
|
||||
@ -799,6 +739,8 @@ def _check_module(mod: ir.Module, *,
|
||||
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
||||
if call_target_name_attr == sharding_attr:
|
||||
check_sharding(op, op.location)
|
||||
elif call_target_name_attr == shape_assertion_attr:
|
||||
assert (DisabledSafetyCheck.shape_assertions() not in disabled_checks)
|
||||
|
||||
def walk_operations(op):
|
||||
check_op(op)
|
||||
|
@ -38,6 +38,7 @@ import itertools
|
||||
import io
|
||||
import math
|
||||
import operator as op
|
||||
import threading
|
||||
import tokenize
|
||||
from typing import Any, Callable, Iterable, Optional, Sequence, Union
|
||||
|
||||
@ -50,6 +51,7 @@ from jax.interpreters import xla
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src.lax import lax
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import lax_numpy
|
||||
@ -79,6 +81,15 @@ for more details.
|
||||
# https://github.com/python/mypy/issues/5887
|
||||
super().__init__(error_msg) # type: ignore
|
||||
|
||||
class _ShapePolyThreadLocalState(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
# TODO(necula): this does not play well with some lowering caches, because
|
||||
# this state is not part of the cache key.
|
||||
self.enable_shape_assertions = True
|
||||
|
||||
thread_local_state = _ShapePolyThreadLocalState()
|
||||
|
||||
class _DimAtom:
|
||||
"""Represents an atom in a symbolic dimension expression.
|
||||
|
||||
@ -792,6 +803,62 @@ def _einsum_contract_path(*operands, **kwargs):
|
||||
|
||||
lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
|
||||
|
||||
# To implement shape-constraint checking we use a shape assertion primitive.
|
||||
# shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
|
||||
# error_message="...{0}...{1}")
|
||||
# where "{0}" refers to error_message_inputs[0], etc.
|
||||
shape_assertion_p = core.Primitive("shape_assertion")
|
||||
shape_assertion_p.multiple_results = True
|
||||
shape_assertion_p.def_effectful_abstract_eval(
|
||||
lambda *_, **__: ((), {shape_assertion_effect})) # type: ignore
|
||||
|
||||
def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
|
||||
assert_what: mlir.ir.Value,
|
||||
*error_message_inputs: mlir.ir.Value,
|
||||
error_message: str):
|
||||
op = mlir.custom_call(
|
||||
"shape_assertion",
|
||||
[], # No results
|
||||
[assert_what, *error_message_inputs],
|
||||
has_side_effect=True,
|
||||
extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
|
||||
)
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule)
|
||||
|
||||
class ShapeAssertionEffect(effects.Effect):
|
||||
__str__ = lambda _: "ShapeAssertionEffect"
|
||||
|
||||
shape_assertion_effect = ShapeAssertionEffect()
|
||||
|
||||
effects.lowerable_effects.add_type(ShapeAssertionEffect)
|
||||
effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect)
|
||||
effects.remat_allowed_effects.add_type(ShapeAssertionEffect)
|
||||
effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect)
|
||||
|
||||
def shape_assertion(assert_what: jax.Array,
                    *error_message_inputs: jax.Array,
                    error_message: str) -> None:
  """Adds a shape assertion to the code.

  Args:
    assert_what: a boolean asserted to be true. Must be computed based only
      on dimension expressions, so that it can be evaluated after shape
      refinement.
    error_message_inputs: integer expressions whose values can be referenced
      in the `error_message`. Must be computed based only
      on dimension expressions, so that they can be evaluated after shape
      refinement.
    error_message: an error message, possibly containing format specifiers
      {0}, {1}, ..., referencing the values of the `error_message_inputs`.
      The format specifiers are sometimes processed with Python's
      `str.format` method, and sometimes with `llvm::formatv`.
  """
  if thread_local_state.enable_shape_assertions:
    shape_assertion_p.bind(assert_what, *error_message_inputs,
                           error_message=error_message)

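For illustration, a minimal hedged sketch of how this helper is intended to be called from lowering code, mirroring the call in `ShapeConstraints.shape_assertions` further down in this diff; `dim` is a made-up value assumed to be computed from dimension expressions:

```python
# Assumes the surrounding shape_poly module, where `shape_assertion` and the
# thread-local `enable_shape_assertions` flag are defined.
def assert_dimension_is_even(dim):
  # Binds shape_assertion_p only when assertions are enabled; "{0}" is later
  # substituted with the remainder value by the consumer of the custom call.
  shape_assertion(
      dim % 2 == 0, dim % 2,
      error_message="Expected an even dimension, found non-zero remainder {0}")
```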
# A JAX primitive with no array arguments but with a dimension parameter
|
||||
# that is a DimExpr. The value of the primitive is the value of the dimension,
|
||||
# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
|
||||
@ -1055,22 +1122,6 @@ def _evaluate_multiply(v1, v2):
|
||||
pass
|
||||
return v1 * v2
|
||||
|
||||
def _is_known_constant(v) -> Optional[int]:
|
||||
try:
|
||||
return int(v)
|
||||
except Exception:
|
||||
# TODO(necula): added this so that in jax2tf, in Eager mode, we can tell
|
||||
# that a tensor is a constant. We should move this dependency into some
|
||||
# jax2tf-specific area.
|
||||
if hasattr(v, "val"):
|
||||
try:
|
||||
vint = int(v.val)
|
||||
if isinstance(vint, int): # In TF, int(tf.Tensor) is tf.Tensor!
|
||||
return vint
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# dimension_size(operand, dimension=i) get the operand.shape[i] as a
|
||||
# value of type shape_poly.dim_as_value_dtype().
|
||||
dimension_size_p = core.Primitive("dimension_size")
|
||||
@ -1180,8 +1231,8 @@ class ShapeConstraint:
|
||||
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
|
||||
self.right]]
|
||||
# Try to evaluate the constraint statically.
|
||||
left_int, right_int = _is_known_constant(left), _is_known_constant(right)
|
||||
if left_int is not None and right_int is not None:
|
||||
if core.is_constant_shape((left, right)):
|
||||
left_int, right_int = op.index(left), op.index(right)
|
||||
if self.comp == ShapeConstraint.Comparator.EQ:
|
||||
if not (left_int == right_int):
|
||||
raise ValueError(self.make_err_msg(left_int, right_int))
|
||||
@ -1225,34 +1276,21 @@ class ShapeConstraints:
|
||||
for constraint in self.constraints:
|
||||
constraint.check_statically(shapeenv)
|
||||
|
||||
def compute(self, shapeenv: DimVarEnv) -> tuple[jax.Array, jax.Array, jax.Array, Sequence[str]]:
|
||||
"""Computes the error code for the set of constraints.
|
||||
def shape_assertions(self, shapeenv: DimVarEnv) -> None:
|
||||
"""Computes the shape assertions the set of constraints.
|
||||
|
||||
The error code is -1 if all constraints are satisfied, or an index into
|
||||
the returned error messages.
|
||||
See Exported.shape_check_module docstring.
|
||||
See _wrap_main_func docstring.
|
||||
"""
|
||||
# We want to report the errors in the same order as `check_statically`.
|
||||
# So, we process them in order, in case some fail statically, but we
|
||||
# accumulate the errors in reverse order in the computation, because the
|
||||
# code-generation strategy will report the last error that failed.
|
||||
acc: list[tuple[jax.Array, jax.Array, jax.Array, str]] = []
|
||||
# So, we process them in order, in case some fail statically, and we
|
||||
# generate the shape assertions in the same order.
|
||||
for constraint in self.constraints:
|
||||
check_res = constraint.compute(shapeenv)
|
||||
if check_res is not None:
|
||||
acc.append((*check_res, constraint.make_err_msg("%1", "%2"))) # type: ignore
|
||||
|
||||
shape_check_messages: list[str] = []
|
||||
shape_check_code: jax.Array = np.int32(-1) # type: ignore
|
||||
shape_check_op1 = shape_check_op2 = shape_check_code
|
||||
for (is_ok, op1, op2, msg) in reversed(acc):
|
||||
shape_check_code = lax.select(is_ok, shape_check_code,
|
||||
np.int32(len(shape_check_messages)))
|
||||
shape_check_op1 = lax.select(is_ok, shape_check_op1, op1)
|
||||
shape_check_op2 = lax.select(is_ok, shape_check_op2, op2)
|
||||
shape_check_messages.append(msg)
|
||||
return (shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages)
|
||||
|
||||
is_ok, op1, op2 = check_res
|
||||
shape_assertion(
|
||||
is_ok, op1, op2,
|
||||
error_message=constraint.make_err_msg("{0}", "{1}"))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DimEquation:
|
||||
@ -1352,7 +1390,8 @@ def compute_dim_vars_from_arg_shapes(
|
||||
|
||||
Like `solve_dim_vars` except that here we express the solution as
|
||||
JAX arrays that reference the `actual_args`. This function can be used to
|
||||
generate the code for computing the dimension variables.
|
||||
generate the code for computing the dimension variables. It also generates
|
||||
the shape assertions.
|
||||
|
||||
Returns: the values of the dimension variables, in the order determined by
|
||||
`all_dim_vars(args_avals)`.
|
||||
@ -1366,39 +1405,9 @@ def compute_dim_vars_from_arg_shapes(
|
||||
dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
dim_values = [solution[var].evaluate(synthetic_env) for var in dim_vars]
|
||||
shape_constraints.shape_assertions(synthetic_env)
|
||||
return tuple(dim_values)
|
||||
|
||||
|
||||
def compute_shape_check_from_arg_shapes(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
*actual_args: jax.Array,
|
||||
args_kwargs_tree: tree_util.PyTreeDef,
|
||||
acc_shape_check_messages: list[str]) -> Sequence[jax.Array]:
|
||||
"""Computes the shape check code from the actual arguments.
|
||||
|
||||
`acc_shape_check_messages` is an initially empty list where we append the
|
||||
shape checking messages. Each of these correspond to a constraint.
|
||||
See the `Exported.shape_check_module` docstring.
|
||||
|
||||
Returns: a triple of JAX arrays: the shape check code and the values of the
|
||||
two operands for the failed constraint.
|
||||
|
||||
Raises ValueError if a constraint is invalidated statically.
|
||||
"""
|
||||
assert not acc_shape_check_messages
|
||||
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
|
||||
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
|
||||
|
||||
# Replace the synthetic vars with the dynamic shape of the actual arg
|
||||
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
|
||||
dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages = (
|
||||
shape_constraints.compute(synthetic_env))
|
||||
acc_shape_check_messages.extend(shape_check_messages)
|
||||
return (shape_check_code, shape_check_op1, shape_check_op2)
|
||||
|
||||
|
||||
def _solve_dim_equations(
|
||||
eqns: list[_DimEquation]
|
||||
) -> tuple[DimVarEnv, ShapeConstraints]:
|
||||
|
@ -125,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
covered_targets = covered_targets.union({
|
||||
"tpu_custom_call", # tested separately
|
||||
# TODO(necula): add tests for shape_assertion
|
||||
"shape_assertion",
|
||||
})
|
||||
not_covered = targets_to_cover.difference(covered_targets)
|
||||
self.assertEmpty(not_covered)
|
||||
@ -608,7 +610,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17)
|
||||
self.run_one_test(
|
||||
func, data,
|
||||
polymorphic_shapes=("b, ...",))
|
||||
polymorphic_shapes=("b, ...",),
|
||||
# TODO(necula): now also includes shape_assertion
|
||||
compare_with_current=False)
|
||||
|
||||
def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
|
||||
# stablehlo.dynamic_reduce_window is used temporarily on TPU for a
|
||||
@ -629,7 +633,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17)
|
||||
self.run_one_test(
|
||||
func, data,
|
||||
polymorphic_shapes=("b, ...", "b, ..."))
|
||||
polymorphic_shapes=("b, ...", "b, ..."),
|
||||
# TODO(necula): now also includes shape_assertion
|
||||
compare_with_current=False)
|
||||
|
||||
def test_stablehlo_dynamic_rbg_bit_generator(self):
|
||||
# stablehlo.dynamic_rbg_bit_generator is used temporarily for a
|
||||
@ -660,7 +666,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
try:
|
||||
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
|
||||
|
||||
self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"))
|
||||
self.run_one_test(func, data, polymorphic_shapes=(None, "b0, b1"),
|
||||
# TODO(necula): now also includes shape_assertion
|
||||
compare_with_current=False)
|
||||
finally:
|
||||
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import re
|
||||
from typing import Optional, Sequence
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -28,136 +27,14 @@ from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
if xc.mlir_api_version >= 50:
|
||||
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
|
||||
skip_shape_check: bool = False):
|
||||
assert all(core.is_constant_shape(a.shape) for a in args)
|
||||
return jax_export.call_exported(exp)(*args)
|
||||
else:
|
||||
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
|
||||
skip_shape_check: bool = False):
|
||||
"""Temporary runner for an Exported.
|
||||
|
||||
Normally we would use jax_export.call_exported, but if the exported has
|
||||
shape polymorphism and we are using jaxlib before 0.4.12 we use
|
||||
Once we upgrade the jaxlib we can replace all uses of this function with
|
||||
jax_export.call_exported.
|
||||
"""
|
||||
assert all(core.is_constant_shape(a.shape) for a in args)
|
||||
if not exp.module_uses_dim_vars:
|
||||
return jax_export.call_exported(exp)(*args)
|
||||
else:
|
||||
# We only get here in external tests, because internal ones use newest jaxlib
|
||||
from jax.experimental.jax2tf import jax2tf # TODO: temporary
|
||||
# call_exported does the shape checking, we must do it manually for
|
||||
# XlaCallModule.
|
||||
if not skip_shape_check:
|
||||
shape_check = exp.shape_check_module()
|
||||
if shape_check is not None:
|
||||
err_msg = _run_shape_check_module(exp, shape_check[0], shape_check[1], args)
|
||||
if err_msg is not None:
|
||||
raise ValueError(err_msg)
|
||||
|
||||
numpy_results = map(lambda res_tf: res_tf.numpy(),
|
||||
jax2tf._run_exported_as_tf(args, exp))
|
||||
return exp.out_tree.unflatten(numpy_results)
|
||||
|
||||
|
||||
def _run_shape_check_module(primal_exported: jax_export.Exported,
|
||||
shape_check_module_serialized: bytes,
|
||||
shape_check_messages: Sequence[str],
|
||||
args: Sequence[jax.Array]
|
||||
) -> Optional[str]:
|
||||
"""Helper to run a shape checking module.
|
||||
|
||||
We only need to do this in tests, because otherwise this will be done
|
||||
implicitly when we call XlaCallModule.
|
||||
|
||||
Returns: the error message, or None if no error.
|
||||
"""
|
||||
args = tuple(args)
|
||||
# We cannot just make an Exported and run it, because call_exported will
|
||||
# do the shape checks statically. So, we wrap the shape check module
|
||||
# with static shape arguments
|
||||
|
||||
static_in_avals = tuple(core.get_aval(a) for a in args)
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
wrapped_module = ir.Module.parse(
|
||||
xla_extension.mlir.deserialize_portable_artifact(
|
||||
shape_check_module_serialized))
|
||||
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
||||
orig_main = symbol_table["main"]
|
||||
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
||||
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
||||
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
|
||||
# Use static shapes
|
||||
new_main_input_types = [mlir.aval_to_ir_type(a) for a in static_in_avals]
|
||||
orig_output_types = orig_main.type.results
|
||||
new_main_ftype = ir.FunctionType.get(
|
||||
new_main_input_types, orig_output_types
|
||||
)
|
||||
new_main_op = func_dialect.FuncOp(
|
||||
"main",
|
||||
new_main_ftype,
|
||||
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
|
||||
)
|
||||
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
symbol_table.insert(new_main_op)
|
||||
entry_block = new_main_op.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
orig_main_args: list[ir.Value] = []
|
||||
for new_arg, orig_arg_type in zip(
|
||||
new_main_op.arguments, orig_main.type.inputs
|
||||
):
|
||||
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
|
||||
call = func_dialect.CallOp(
|
||||
orig_output_types,
|
||||
ir.FlatSymbolRefAttr.get(orig_main_name),
|
||||
orig_main_args,
|
||||
)
|
||||
func_dialect.ReturnOp(call.results)
|
||||
symbol_table.set_symbol_name(new_main_op, "main")
|
||||
|
||||
wrapped_module_serialized, version = jax_export._serialize_module(wrapped_module)
|
||||
# Make an Exported and then run it
|
||||
out_avals = (core.ShapedArray((), dtype=np.int32),) * 3
|
||||
exp = jax_export.Exported(
|
||||
fun_name=f"shape_check_{primal_exported.fun_name}",
|
||||
in_tree=tree_util.tree_flatten((args, {}))[1],
|
||||
in_avals=static_in_avals,
|
||||
out_tree=tree_util.tree_flatten(out_avals)[1],
|
||||
out_avals=out_avals,
|
||||
in_shardings=None,
|
||||
out_shardings=None,
|
||||
lowering_platform=primal_exported.lowering_platform,
|
||||
disabled_checks=(),
|
||||
mlir_module_serialized=wrapped_module_serialized,
|
||||
xla_call_module_version=version,
|
||||
module_kept_var_idx=tuple(sorted(range(len(static_in_avals)))),
|
||||
module_uses_dim_vars=True,
|
||||
_get_vjp=lambda _: None) # type: ignore
|
||||
|
||||
code, op1, op2 = _temp_call_exported(exp, *args,
|
||||
skip_shape_check=True)
|
||||
if code == -1:
|
||||
return None
|
||||
else:
|
||||
return shape_check_messages[code].replace("%1", str(int(op1))).replace(
|
||||
"%2", str(int(op2)))
|
||||
|
||||
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic_export_only(self):
|
||||
@ -396,7 +273,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
|
||||
dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
|
||||
# The shape check module does not test constant dimensions
|
||||
expect_error_run=re.escape(
|
||||
expect_error=re.escape(
|
||||
r"Shape mismatch for args[0].shape[2] (expected same constant)")),
|
||||
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
|
||||
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
|
||||
@ -418,49 +295,31 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def test_poly_shape_checks(
|
||||
self, poly_spec="3,a,a+8",
|
||||
arg_shape=(3, 4, 12), arg_dtype=np.float32,
|
||||
expect_error=None, # If given, applies for expect_error_run and expect_error_shape_check
|
||||
expect_error_run=None, # Error from running the exported module
|
||||
expect_error_shape_check=None): # Error from running the shape check module
|
||||
if expect_error is not None:
|
||||
self.assertIsNone(expect_error_run, None)
|
||||
self.assertIsNone(expect_error_shape_check, None)
|
||||
expect_error_run = expect_error_shape_check = expect_error
|
||||
expect_error=None): # If given, error from running the exported module
|
||||
|
||||
def f(x): # x: f32[poly_spec]
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
exp_f = jax_export.export(f)(
|
||||
if xc.mlir_api_version <= 51:
|
||||
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
|
||||
else:
|
||||
disabled_checks = ()
|
||||
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
|
||||
self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12")
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error))
|
||||
|
||||
res = _temp_call_exported(exp_f, arg)
|
||||
assert core.is_constant_shape(arg.shape)
|
||||
res = jax_export.call_exported(exp_f)(arg)
|
||||
|
||||
if not expect_error_run:
|
||||
if not expect_error:
|
||||
self.assertAllClose(res, f(arg))
|
||||
|
||||
# Test the shape_check_module
|
||||
shape_check = exp_f.shape_check_module()
|
||||
# We have shape_check only if the exported has polymorphic inputs shapes.
|
||||
if all(core.is_constant_shape(a.shape) for a in exp_f.in_avals):
|
||||
self.assertIsNone(shape_check)
|
||||
self.assertIsNone(expect_error_shape_check)
|
||||
else:
|
||||
self.assertIsNotNone(shape_check)
|
||||
shape_check_module, shape_check_messages = shape_check
|
||||
err_msg = _run_shape_check_module(exp_f,
|
||||
shape_check_module, shape_check_messages, (arg,))
|
||||
|
||||
if expect_error_shape_check is None:
|
||||
self.assertIsNone(err_msg)
|
||||
else:
|
||||
self.assertRegex(err_msg, expect_error_shape_check)
|
||||
|
||||
|
||||
# An inner function is exported with polymorphic shapes inner_poly_spec, and
|
||||
# is called from an outer function, which is exported with outer_poly_spec.
|
||||
@jtu.parameterized_filterable(
|
||||
@ -536,12 +395,16 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
expect_error_outer_exp=None,
|
||||
expect_error_run=None):
|
||||
# Polymorphic export called with static or polymorphic shapes
|
||||
if xc.mlir_api_version <= 51:
|
||||
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
|
||||
else:
|
||||
disabled_checks = ()
|
||||
def inner(x): # x: inner_poly_spec
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
|
||||
inner_exp = jax_export.export(inner)(
|
||||
inner_exp = jax_export.export(inner, disabled_checks=disabled_checks)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
|
||||
|
||||
self.assertEqual(inner_exp.module_uses_dim_vars,
|
||||
@ -556,7 +419,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
|
||||
|
||||
# Call it after exporting again, with polymorphic shapes
|
||||
outer_exp = jax_export.export(outer)(
|
||||
outer_exp = jax_export.export(outer, disabled_checks=disabled_checks)(
|
||||
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
|
||||
|
||||
if expect_error_outer_exp is not None:
|
||||
@ -564,17 +427,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(outer_exp.module_uses_dim_vars,
|
||||
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
|
||||
shape_check = outer_exp.shape_check_module()
|
||||
if all(core.is_constant_shape(a.shape) for a in outer_exp.in_avals):
|
||||
self.assertIsNone(shape_check)
|
||||
else:
|
||||
self.assertIsNotNone(shape_check)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
|
||||
res = _temp_call_exported(outer_exp, arg)
|
||||
res = jax_export.call_exported(outer_exp)(arg)
|
||||
|
||||
if expect_error_run is not None:
|
||||
return
|
||||
|
@ -87,7 +87,9 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", tshape))
|
||||
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
@parameterized.named_parameters(
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_{dim_spec}",
|
||||
dim_spec=dim_spec, dim_poly=dim_poly)
|
||||
for dim_spec, dim_poly in [
|
||||
@ -100,34 +102,36 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
("a + -1", a - 1),
|
||||
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
|
||||
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
|
||||
])
|
||||
]])
|
||||
def test_parse_dim(self,
|
||||
dim_spec="-2 * a^2 * b + b^2",
|
||||
dim_poly=-2 * a * a * b + b * b):
|
||||
self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (None,)))
|
||||
self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (None,)))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_{shape_spec=}",
|
||||
shape_spec=shape_spec)
|
||||
for shape_spec in [
|
||||
"2.5", "a + a a", "a ^ a", "a, a",
|
||||
"_", "...", "a ;", ")(", "2a", "a@", "'a'", "('a', ...)",
|
||||
"mod(a)", "floordiv(a, b, c)", "..., 3"
|
||||
])
|
||||
]])
|
||||
def test_parse_error(self,
|
||||
shape_spec="a + a a"):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"syntax error in polymorphic shape"):
|
||||
shape_poly._parse_spec(shape_spec, (None,))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_{shape_spec=}",
|
||||
shape_spec=shape_spec, arg_shape=arg_shape)
|
||||
for shape_spec, arg_shape in [
|
||||
("3", (4,)),
|
||||
("b, 3", (None, 4)),
|
||||
])
|
||||
]])
|
||||
def test_parse_mismatch_error(self,
|
||||
shape_spec="3", arg_shape=(4,)):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
@ -347,7 +351,8 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(a * 2 // a, 2)
|
||||
self.assertIsInstance(a * 2 // a, int)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
|
||||
dividend=dividend, divisor=divisor, quotient=quotient,
|
||||
remainder=remainder)
|
||||
@ -364,7 +369,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
|
||||
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"),
|
||||
(3, a, "floordiv(3, a)", "mod(3, a)"),
|
||||
])
|
||||
]])
|
||||
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
if isinstance(quotient, str):
|
||||
d1, d2 = divmod(dividend, divisor)
|
||||
@ -930,37 +935,26 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
re.escape("pytree structure error: different types")):
|
||||
conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",))
|
||||
|
||||
# The following do not work yet with native serialization because
|
||||
# XlaCallModule does not yet do shape checking.
|
||||
if config.jax2tf_default_native_serialization:
|
||||
return
|
||||
|
||||
# TODO(necula): enable even for native serialization
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Cannot solve for values of dimension variables {'b'}"):
|
||||
conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d")
|
||||
|
||||
# TODO(necula): enable even for native serialization
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
|
||||
"Dimension variable 'b' must have integer value >= 1"):
|
||||
conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...")
|
||||
|
||||
# TODO(necula): enable even for native serialization
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
|
||||
"Dimension variable 'b' must have integer value >= 1"):
|
||||
conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...")
|
||||
|
||||
# TODO(necula): enable even for native serialization
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
|
||||
"Dimension variable 'b' must have integer value >= 1"):
|
||||
conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...")
|
||||
|
||||
# TODO(necula): enable even for native serialization
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Found inconsistency 3 != 2 when solving.*"):
|
||||
conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)")
|
||||
|
||||
def test_pytree(self):
|
||||
"""Arguments and polymorphic_shapes are pytrees."""
|
||||
|
||||
@ -1250,7 +1244,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
|
||||
self.assertAllClose(f(x), res_tf)
|
||||
|
||||
@jtu.sample_product(with_function=[False, True])
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[dict(with_function=v) for v in [True, False]]
|
||||
)
|
||||
def test_grad_int(self, with_function=False):
|
||||
# https://github.com/google/jax/issues/7093
|
||||
# Also issue #6975.
|
||||
@ -1418,34 +1414,56 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1,
|
||||
polymorphic_shapes=["(a, b)"])(np.ones((4, 4)))
|
||||
|
||||
# Unsoundness: not checking that the dimension variable is 0
|
||||
# Checking that the dimension variable is >= 1
|
||||
def f1_jax(x): # f32[b]
|
||||
# We have to use "x"
|
||||
return jnp.concatenate([x, jnp.array([0. if x.shape[0] == 0 else 1.],
|
||||
dtype=np.float32)])
|
||||
|
||||
x0 = np.array([], np.float32)
|
||||
# JAX with static shapes sees that the x.shape[0] == 0
|
||||
self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
|
||||
|
||||
# In graph serialization eager mode we catch the error
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Dimension variable 'b' must have integer value >= 1. Found 0"):
|
||||
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=False)(x0)
|
||||
|
||||
# In native serialization, or if we trace to a TF graph, we miss this
|
||||
res1_tf = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=True)(x0)
|
||||
self.assertEqual(jnp.array([1.], dtype=np.float32), res1_tf)
|
||||
|
||||
# In graph serialization graph mode we also catch it (except on TPU)
|
||||
f1_tf = tf.function(
|
||||
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=False)
|
||||
native_serialization=False),
|
||||
autograph=False,
|
||||
).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
|
||||
# In graph serialization graph mode we also catch it (except on TPU, where
|
||||
# the behavior is as for jit_compile=1)
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Dimension variable .* must have integer value"):
|
||||
_ = f1_tf(x0)
|
||||
|
||||
# In graph serialization with jit_compile=True we do not catch the error
|
||||
# and we return the wrong result
|
||||
f1_tf = tf.function(
|
||||
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=False),
|
||||
autograph=False,
|
||||
jit_compile=True
|
||||
)
|
||||
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
|
||||
|
||||
# Unsoundness: not checking that the actual dimensions denoted by the same
|
||||
# We also catch the error with native serialization
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Dimension variable 'b' must have integer value >= 1. Found 0"):
|
||||
_ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
|
||||
native_serialization=True)(x0)
|
||||
|
||||
# Checking that the actual dimensions denoted by the same
|
||||
# dimension variables have equal sizes.
|
||||
def f2_jax(x): # f32[b, b]
|
||||
# We have to use "x"
|
||||
@ -1455,25 +1473,46 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# JAX with static shapes sees that x.shape[0] != x.shape[1]
|
||||
self.assertEqual(jnp.sum(x45), f2_jax(x45))
|
||||
|
||||
# jax2tf catches the broken assumption b >= 1 if the converted function is executed
|
||||
# eagerly.
|
||||
# In graph serialization eager mode, we catch the broken assumption b >= 1
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
tf.errors.InvalidArgumentError,
|
||||
r"Found inconsistency 5 != 4 when solving b == args\[0\].shape\[1\]"):
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False)(x45)
|
||||
|
||||
# In native serialization, or if we trace to a TF graph, we miss this
|
||||
res2_tf = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=True)(x45)
|
||||
self.assertEqual(1. + jnp.sum(x45), res2_tf)
|
||||
# In graph serialization graph mode we also catch it (except on TPU, where
|
||||
# the behavior is as for jit_compile=1)
|
||||
|
||||
f2_tf = tf.function(
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False)
|
||||
native_serialization=False),
|
||||
autograph=False,
|
||||
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
r"Found inconsistency"):
|
||||
_ = f2_tf(x45)
|
||||
|
||||
# In graph serialization with jit_compile=True we do not catch the error
|
||||
# and we return the wrong result
|
||||
f2_tf = tf.function(
|
||||
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=False),
|
||||
autograph=False,
|
||||
jit_compile=True
|
||||
)
|
||||
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
|
||||
|
||||
# We also catch the error with native serialization
|
||||
with self.assertRaisesRegex(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Found inconsistency 5 != 4"):
|
||||
_ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
|
||||
native_serialization=True)(x45)
|
||||
|
||||
x = np.ones((5,), dtype=np.float32)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
|
@ -27,7 +27,6 @@ from typing import Any, Sequence
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
@ -180,12 +179,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
else:
|
||||
assert False
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
|
||||
in_shardings=in_shardings, out_shardings=out_shardings)
|
||||
for in_shardings in ("missing", None, "P")
|
||||
for out_shardings in ("missing", None, "P")
|
||||
)
|
||||
])
|
||||
@jtu.with_mesh([("x", 2)])
|
||||
def test_pjit_basic(self, in_shardings="P", out_shardings="P"):
|
||||
# Ensure that we can distinguish the inputs and outputs by shape
|
||||
@ -316,14 +316,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
res_tf = f_tf(x)
|
||||
self.assertAllClose(res_tf, res_jax)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}",
|
||||
nested_pjit=nested_pjit, constraint=constraint, poly=poly)
|
||||
# We add a constraint either with a nested pjit or with a sharding_constraint
|
||||
for nested_pjit in (True, False)
|
||||
for constraint in (None, "P")
|
||||
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
|
||||
)
|
||||
])
|
||||
@jtu.with_mesh([("x", 2)])
|
||||
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"):
|
||||
constraint_sharding = P("x", None) if constraint == "P" else None
|
||||
@ -363,12 +364,13 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
(r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated)
|
||||
])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
|
||||
in_shardings=in_shardings, out_shardings=out_shardings)
|
||||
for in_shardings in ("missing", None, "P")
|
||||
for out_shardings in ("missing", None, "P")
|
||||
)
|
||||
])
|
||||
@jtu.with_mesh([("x", 2)])
|
||||
def test_grad_pjit(self, in_shardings="P", out_shardings=None):
|
||||
def f_jax(x): # x: f32[10,20] -> f32[20,10]
|
||||
@ -414,7 +416,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
|
||||
])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}",
|
||||
kind=kind, in_shardings=in_shardings, out_shardings=out_shardings)
|
||||
for kind in ("pjit", "jit", "sharding_constraint")
|
||||
@ -425,7 +428,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
for out_shardings in (
|
||||
("unspecified",) if kind in ["sharding_constraint", "jit"] else
|
||||
("unspecified", "none", "P"))
|
||||
)
|
||||
])
|
||||
def test_pjit_error_inner_sharding(self, kind="pjit", in_shardings="P",
|
||||
out_shardings="none"):
|
||||
# Check that we raise an error if there is no top-level pjit but we convert
|
||||
@ -462,11 +465,12 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
with Mesh(self.devices, axis_names=("x",)):
|
||||
f_tf(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_func={func}", func=func)
|
||||
for func in ("pjit_sharded", "pjit_replicated",
|
||||
"nested_pjit_sharded", "nested_pjit_replicated")
|
||||
)
|
||||
])
|
||||
def test_pjit_eager_error(self, func="pjit_sharded"):
|
||||
if config.jax2tf_default_native_serialization:
|
||||
raise unittest.SkipTest("There is no error in eager mode for native serialization")
|
||||
@ -693,10 +697,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
res_tf = f_tf(a)
|
||||
self.assertAllClose(res_tf, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_poly={poly}", poly=poly)
|
||||
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
|
||||
)
|
||||
])
|
||||
def test_shmap_collective_permute(self, poly=None):
|
||||
if jtu.device_under_test() == "cpu":
|
||||
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
|
||||
|