Merge pull request #16211 from gnecula:poly_checks
PiperOrigin-RevId: 540030091
Commit: e9f2e40a3e
@@ -22,7 +22,7 @@ import re
import os
import tempfile
import textwrap
from typing import Callable, List, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Generator, Optional, Sequence, Tuple, Union
import unittest
import warnings
import zlib
@@ -1207,3 +1207,42 @@ def _parse_version(v: str) -> Tuple[int, ...]:

def numpy_version():
  return _parse_version(np.__version__)


def parameterized_filterable(*,
                             kwargs: Sequence[Dict[str, Any]],
                             testcase_name: Optional[Callable[[Dict[str, Any]], str]] = None,
                             one_containing: Optional[str] = None,
                             ):
  """
  Decorator for named parameterized tests, with filtering.

  Works like parameterized.named_parameters, except that it supports the
  `one_containing` option. This is useful for selecting only one of the tests
  while leaving the test name unchanged (which helps when specifying the
  desired test for debugging).

  Args:
    kwargs: Each entry is a set of kwargs to be passed to the test function.
    testcase_name: Optionally, a function to construct the testcase_name from
      one kwargs dict. If not given, then each kwargs dict must contain
      `testcase_name`.
    one_containing: If given, leave the test name unchanged and use only the
      first `kwargs` entry whose `testcase_name` includes `one_containing`.
  """
  # Ensure that all kwargs contain a testcase_name
  kwargs_with_testcase_name: Sequence[Dict[str, Any]]
  if testcase_name is not None:
    kwargs_with_testcase_name = [dict(testcase_name=testcase_name(kw), **kw)
                                 for kw in kwargs]
  else:
    for kw in kwargs:
      assert "testcase_name" in kw
    kwargs_with_testcase_name = kwargs
  if one_containing is not None:
    filtered = tuple(kw for kw in kwargs_with_testcase_name
                     if one_containing in kw["testcase_name"])
    assert filtered, f"No testcase_name contains '{one_containing}'"
    kw = filtered[0]
    kw["testcase_name"] = ""
    return parameterized.named_parameters([kw])
  else:
    return parameterized.named_parameters(*kwargs_with_testcase_name)
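
A usage sketch for the decorator above (hypothetical test class; the decorator is applied the same way in the test changes further down in this diff):

  from absl.testing import absltest
  from jax._src import test_util as jtu

  class ExampleTest(jtu.JaxTestCase):

    @jtu.parameterized_filterable(
        testcase_name=lambda kw: kw["platform"],
        # one_containing="cpu",  # keep only the test whose name contains "cpu"
        kwargs=[dict(platform=p) for p in ("cpu", "cuda", "rocm", "tpu")])
    def test_platform(self, platform):
      self.assertIn(platform, ("cpu", "cuda", "rocm", "tpu"))

  if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())
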
@@ -575,6 +575,18 @@ class GraphSerializationImpl(SerializationImpl):
        partial(shape_poly.compute_dim_vars_from_arg_shapes,
                self.args_avals_flat, args_kwargs_tree=self.in_tree),
        self.args_flat_tf, self.args_avals_flat, self.name_stack)

    # We invoke shape checking to give it a chance to raise shape errors that
    # are evident statically. This should work in TF eager mode because all
    # the shapes are known.
    # TODO: handle non-static shape checking for graph serialization
    acc_shape_check_messages: List[str] = []
    _, _ = _interpret_fun_jax(
        partial(shape_poly.compute_shape_check_from_arg_shapes,
                self.args_avals_flat, args_kwargs_tree=self.in_tree,
                acc_shape_check_messages=acc_shape_check_messages),
        self.args_flat_tf, self.args_avals_flat, self.name_stack)

    _thread_local_state.shape_env = zip(dim_vars, dim_values)

    fun_flat_jax, out_tree_thunk = flatten_fun_jax(self.fun_jax, self.in_tree)

@@ -23,6 +23,7 @@ import re
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union

from absl import logging
import numpy as np

import jax
from jax import sharding
@@ -101,10 +102,10 @@ class Exported:
      expressions in the shapes, with dimension variables among those in
      `in_avals`.
    in_shardings: the flattened input shardings. Only for the inputs that are
      specified in `module_kept_var_idx`.
      specified in `module_kept_var_idx`. If `None` then it is equivalent
      to unspecified shardings.
    out_shardings: the flattened output shardings, with the same length as `in_avals`.
    lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'

    mlir_module_serialized: the serialized lowered VHLO module.
    xla_call_module_version: a version number for the serialized module.
      See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code
@@ -113,7 +114,7 @@ class Exported:
      because they are not used. Same length as `in_shardings`.
    module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
      polymorphic dimension variables. This may be from `in_avals` but also
      from inner calls of Exported modules.
      from inner calls of shape-polymorphic Exported modules.
    disabled_checks: a list of descriptors of safety checks that have been
      disabled at export time.
    _get_vjp: an optional function that takes the current exported function and
@@ -129,8 +130,8 @@ class Exported:
  out_tree: tree_util.PyTreeDef
  out_avals: Tuple[core.AbstractValue, ...]

  in_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
  out_shardings: Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]
  in_shardings: Optional[Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
  out_shardings: Optional[Tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
  lowering_platform: str
  disabled_checks: Sequence[DisabledSafetyCheck]

@@ -145,6 +146,66 @@ class Exported:
  def mlir_module(self) -> ir.Module:
    return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)

  def shape_check_module(self) -> Optional[Tuple[bytes, Sequence[str]]]:
    """Generates a serialized shape checking module and the error messages.

    Consider the exporting of a function with one array argument of type
    "f32[w, 2 * h]", where "w" and "h" are two dimension variables. JAX tracing
    assumes that `arg.shape[1]` is even, and that both `w` and `h` have
    values >= 1. We must ensure that these assumptions hold when the function
    is invoked.

    When we `call_exported` this module, we perform these checks
    statically (in `call_exported_abstract_eval`).

    But if we wrap the exported module with XlaCallModule for execution outside
    JAX, we must defer these checks to compile time, when the static shapes
    are known.

    We generate a separate MLIR shape checking module with a main
    function that returns a triple of int32 values: a shape checking code and
    two shape checking operands. The shape checking code is -1 if the checks
    pass, or an index into the returned shape_check_messages otherwise.
    The returned shape checking operands are substituted for "%1" and "%2" in
    the shape checking messages to obtain the error message. Here is a shape
    checking module for our example:

       func public main(arg: f32[?, ?]) {
          # code will be the index of the last failed shape check.
          code = op1 = op2 = -1  # Assume no errors initially
          # Check that w is >= 1
          arg_w = hlo.get_dimension_size(arg, 0)
          code = hlo.select(arg_w >= 1, code, 0)  # error message 0
          op1 = hlo.select(arg_w >= 1, op1, arg_w)
          op2 = hlo.select(arg_w >= 1, op2, 1)
          # Check that dim1 is even
          dim1 = hlo.get_dimension_size(arg, 1)
          code = hlo.select(dim1 % 2 == 0, code, 1)  # error message 1
          op1 = hlo.select(dim1 % 2 == 0, op1, dim1 % 2)
          op2 = hlo.select(dim1 % 2 == 0, op2, 0)
          # Check that h >= 1
          arg_h = hlo.floordiv(dim1, 2)
          code = hlo.select(arg_h >= 1, code, 2)  # error message 2
          op1 = hlo.select(arg_h >= 1, op1, arg_h)
          op2 = hlo.select(arg_h >= 1, op2, 1)
          return (code, op1, op2)
       }

    We also return the following shape checking messages to be used for
    constructing error messages:

      shape_check_messages = [
          "Dimension variable 'w' must have integer value >= 1. Found %1",
          "Dimension variable 'h' must have integer value >= 1. Found non-zero remainder %1",
          "Dimension variable 'h' must have integer value >= 1. Found %1"]
    """
    shape_check_res = _make_shape_check_module(
        self.in_avals, args_kwargs_tree=self.in_tree)
    if shape_check_res is None:
      return None
    module, shape_check_messages = shape_check_res
    module_serialized, xla_call_module_version = _serialize_module(module)
    assert xla_call_module_version == self.xla_call_module_version
    return module_serialized, shape_check_messages
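
A sketch of how a caller might consume the result of `shape_check_module()` (this mirrors the `_run_shape_check_module` test helper further down in this change; `run_module` is an assumed stand-in for actually executing the serialized module, e.g. via XlaCallModule):

  def render_shape_check_error(exp, args, run_module):
    shape_check = exp.shape_check_module()
    if shape_check is None:
      return None  # no shape-polymorphic inputs, nothing to check
    module_serialized, shape_check_messages = shape_check
    # run_module is assumed to return the (code, op1, op2) triple of int32 results.
    code, op1, op2 = run_module(module_serialized, args)
    if code == -1:
      return None  # all checks passed
    # Substitute the two operands into the selected message template.
    return (shape_check_messages[code]
            .replace("%1", str(int(op1)))
            .replace("%2", str(int(op2))))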

  def __str__(self):
    # This is called to make an MLIR source location when we call an Exported,
    # and we do not want the entire serialized module to end up in locations.
@@ -186,6 +247,10 @@ def poly_spec(
    convenience, zero or more trailing `_` can be abbreviated with `...`, and
    the surrounding parentheses may be missing.

  Note that this function does not ensure that the provided `arg_shape`
  is compatible with `polymorphic_shape`. The `arg_shape` is used only
  to fill in placeholders from `polymorphic_shape`.

  See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
  for more details.

@@ -215,6 +280,10 @@ def poly_specs(
  See [how optional parameters are matched to
  arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

  Note that this function does not ensure that the provided `args` shapes
  are compatible with `polymorphic_shapes`. The `args` shapes are used only
  to fill in placeholders from `polymorphic_shapes`.

  See the docstring of `poly_spec` and
  [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
  for more details.
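
A small sketch of how `poly_spec` feeds into `export` (illustrative shapes and spec strings; the same pattern appears in the tests at the end of this change):

  import numpy as np
  from jax import numpy as jnp
  from jax.experimental.jax2tf import jax_export

  def f(x):  # x: f32[w, h]
    return jnp.reshape(x, (-1,))

  # The concrete shape (3, 4) only fills in the placeholders; the export is
  # polymorphic in the dimension variables "w" and "h".
  exp = jax_export.export(f)(jax_export.poly_spec((3, 4), np.float32, "w, h"))
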
@ -290,9 +359,7 @@ def export(fun_jax: Callable,
|
||||
*args_specs, **kwargs_specs,
|
||||
_experimental_lowering_platform=lowering_platform_str)
|
||||
lowering = lowered._lowering # type: ignore
|
||||
|
||||
_check_lowering(lowering)
|
||||
|
||||
mlir_module = lowering.stablehlo()
|
||||
|
||||
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
||||
@ -309,31 +376,7 @@ def export(fun_jax: Callable,
|
||||
mlir_module = _wrap_main_func(
|
||||
mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree
|
||||
)
|
||||
|
||||
xla_call_module_version = 6
|
||||
mlir_str = mlir.module_to_bytecode(mlir_module)
|
||||
if stablehlo.get_api_version() < 4:
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
else:
|
||||
# `target_version` is used to manage situations when a StableHLO producer
|
||||
# (in this case, jax2tf) and a StableHLO consumer were built using
|
||||
# different versions of StableHLO.
|
||||
#
|
||||
# Each StableHLO version `producer_version` has a compatibility window,
|
||||
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
||||
# where StableHLO portable artifacts serialized by `producer_version`
|
||||
# can be deserialized by `consumer_version` within the window.
|
||||
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
||||
# for the exact extent of these compatibility guarantees.
|
||||
#
|
||||
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
|
||||
# for the current version of StableHLO. We are using it here to maximize
|
||||
# forward compatibility, i.e. to maximize how far into the past we can go
|
||||
# and still have the payloads produced by `serialize_portable_artifact`
|
||||
# compatible with potential consumers from the past.
|
||||
target_version = stablehlo.get_minimum_version()
|
||||
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
mlir_module_serialized, xla_call_module_version = _serialize_module(mlir_module)
|
||||
|
||||
# Figure out the result types and shapes
|
||||
if "global_out_avals" in lowering.compile_args:
|
||||
@ -377,6 +420,34 @@ def export(fun_jax: Callable,
|
||||
return do_export
|
||||
|
||||
|
||||
def _serialize_module(module: ir.Module) -> Tuple[bytes, int]:
|
||||
xla_call_module_version = 6
|
||||
mlir_str = mlir.module_to_bytecode(module)
|
||||
if stablehlo.get_api_version() < 4:
|
||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||
else:
|
||||
# `target_version` is used to manage situations when a StableHLO producer
|
||||
# (in this case, jax2tf) and a StableHLO consumer were built using
|
||||
# different versions of StableHLO.
|
||||
#
|
||||
# Each StableHLO version `producer_version` has a compatibility window,
|
||||
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
||||
# where StableHLO portable artifacts serialized by `producer_version`
|
||||
# can be deserialized by `consumer_version` within the window.
|
||||
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
||||
# for the exact extent of these compatibility guarantees.
|
||||
#
|
||||
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
|
||||
# for the current version of StableHLO. We are using it here to maximize
|
||||
# forward compatibility, i.e. to maximize how far into the past we can go
|
||||
# and still have the payloads produced by `serialize_portable_artifact`
|
||||
# compatible with potential consumers from the past.
|
||||
target_version = stablehlo.get_minimum_version()
|
||||
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
return module_serialized, xla_call_module_version
|
||||
|
||||
|
||||
def _wrap_main_func(
|
||||
module: ir.Module,
|
||||
args_avals_flat: Sequence[core.ShapedArray],
|
||||
@ -392,7 +463,7 @@ def _wrap_main_func(
|
||||
the `args_avals`, in sorted order.
|
||||
|
||||
Consider the lowering of a function with one array argument of type "f32[w,
|
||||
h]",
|
||||
2 * h]",
|
||||
where "w" and "h" are two dimension variables. The `module` will also
|
||||
contain two dimension arguments, corresponding to "h" and "w" respectively:
|
||||
|
||||
@ -403,8 +474,9 @@ def _wrap_main_func(
|
||||
we rename "main" to "_wrapped_jax_export_main" and add a new "main":
|
||||
|
||||
func public main(arg: f32[?, ?]) {
|
||||
arg_h = hlo.get_dimension_size(arg, 1)
|
||||
arg_w = hlo.get_dimension_size(arg, 0)
|
||||
dim1 = hlo.get_dimension_size(arg, 1)
|
||||
arg_h = hlo.floordiv(dim1, 2)
|
||||
res = call _wrapped_jax_export_main(arg_h, arg_w, arg)
|
||||
return res
|
||||
}
|
||||
@ -424,11 +496,11 @@ def _wrap_main_func(
|
||||
Returns the wrapped module.
|
||||
"""
|
||||
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
|
||||
# Make a new module, do not mutate the "module" because it may be cached
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
new_module = ir.Module.parse(mlir.module_to_bytecode(module))
|
||||
symbol_table = ir.SymbolTable(new_module.operation)
|
||||
# Make a copy, do not mutate because it may be cached
|
||||
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))
|
||||
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
||||
orig_main = symbol_table["main"]
|
||||
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
||||
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
||||
@ -446,8 +518,9 @@ def _wrap_main_func(
|
||||
nr_array_args = len(orig_input_types) - len(dim_vars) - nr_token_args
|
||||
assert nr_array_args >= 0
|
||||
assert not any(is_token(attrs) for attrs in arg_attrs[-nr_array_args:])
|
||||
new_main_input_types = orig_input_types[-nr_array_args:]
|
||||
|
||||
# The order of args: dim args, token args, array args.
|
||||
new_main_input_types = orig_input_types[- nr_array_args:]
|
||||
dim_var_input_types = orig_input_types[:len(dim_vars)]
|
||||
orig_output_types = orig_main.type.results
|
||||
result_attrs = list(ir.ArrayAttr(orig_main.result_attrs))
|
||||
nr_token_results = sum(1 for attrs in result_attrs if is_token(attrs))
|
||||
@ -457,10 +530,9 @@ def _wrap_main_func(
|
||||
is_token(attrs) for attrs in result_attrs[-nr_array_results:]
|
||||
)
|
||||
new_main_output_types = orig_output_types[-nr_array_results:]
|
||||
|
||||
ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
|
||||
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
|
||||
new_main_op = func_dialect.FuncOp(
|
||||
"main", ftype, ip=ir.InsertionPoint.at_block_begin(new_module.body))
|
||||
"main", new_main_ftype, ip=ir.InsertionPoint.at_block_begin(wrapped_module.body))
|
||||
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
try:
|
||||
new_main_op.arg_attrs = ir.ArrayAttr.get(arg_attrs[-nr_array_args:])
|
||||
@ -479,13 +551,21 @@ def _wrap_main_func(
|
||||
module_context = mlir.ModuleContext(
|
||||
"cpu", "cpu", sharding_impls.ShardingContext([]),
|
||||
source_info_util.new_name_stack(),
|
||||
[], itertools.count(1), [], module=new_module, context=context)
|
||||
[], itertools.count(1), [], module=wrapped_module, context=context)
|
||||
ctx = mlir.LoweringRuleContext(module_context=module_context,
|
||||
primitive=None, avals_in=args_avals_flat, avals_out=None,
|
||||
tokens_in=mlir.TokenSet(), tokens_out=None)
|
||||
dim_args = _compute_dim_args(ctx, args_avals_flat, tuple(new_main_op.arguments),
|
||||
orig_input_types[:len(dim_vars)],
|
||||
args_kwargs_tree=args_kwargs_tree)
|
||||
dim_values = mlir.lower_fun(
|
||||
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
||||
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
|
||||
multiple_results=True)(ctx, *new_main_op.arguments)
|
||||
|
||||
dim_args = []
|
||||
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_var_input_types):
|
||||
if dim_arg.type != dim_arg_type:
|
||||
dim_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
|
||||
else:
|
||||
dim_args.append(dim_arg)
|
||||
# The first arguments are the dimension variable
|
||||
orig_main_args.extend(dim_args)
|
||||
# Then the token arguments
|
||||
@ -497,40 +577,56 @@ def _wrap_main_func(
|
||||
orig_main_args)
|
||||
func_dialect.ReturnOp(call.results[-nr_array_results:])
|
||||
symbol_table.set_symbol_name(new_main_op, "main")
|
||||
return new_module
|
||||
|
||||
return wrapped_module
|
||||
|
||||
|
||||
def _compute_dim_args(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
args_avals_flat: Sequence[core.ShapedArray],
|
||||
array_args: Sequence[ir.Value],
|
||||
dim_arg_types: Sequence[ir.Type], *,
|
||||
args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[ir.Value]:
|
||||
"""Compute the values of the dimension arguments.
|
||||
def _make_shape_check_module(
|
||||
args_avals_flat: Sequence[core.AbstractValue],
|
||||
*,
|
||||
args_kwargs_tree: tree_util.PyTreeDef
|
||||
) -> Optional[Tuple[ir.Module, Sequence[str]]]:
|
||||
"""Codegens the shape checking function.
|
||||
|
||||
Args:
|
||||
args_avals_flat: the abstract values of the array arguments.
|
||||
array_args: the values of the array arguments.
|
||||
dim_arg_types: the desired types for the dimension arguments.
|
||||
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for
|
||||
error messages.
|
||||
The shape checking function takes the array inputs and returns a triple.
|
||||
See `Exported.shape_check_module` docstring.
|
||||
|
||||
Returns:
|
||||
the values of the dimension variables, in the sorted order of the
|
||||
dimension variables.
|
||||
Returns the shape checking module and the tuple of shape checking messages.
|
||||
"""
|
||||
dim_values = mlir.lower_fun(
|
||||
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
|
||||
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
|
||||
multiple_results=True)(ctx, *array_args)
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
array_input_types = tuple(mlir.aval_to_ir_type(a) for a in args_avals_flat)
|
||||
module = ir.Module.create(loc=ir.Location.unknown(context))
|
||||
|
||||
res = []
|
||||
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
|
||||
if dim_arg.type != dim_arg_type:
|
||||
res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
|
||||
else:
|
||||
res.append(dim_arg)
|
||||
return tuple(res)
|
||||
symbol_table = ir.SymbolTable(module.operation)
|
||||
shape_check_code_type = mlir.aval_to_ir_type(core.ShapedArray((), np.int32))
|
||||
shape_check_output_types = (shape_check_code_type,) * 3
|
||||
shape_check_ftype = ir.FunctionType.get(array_input_types,
|
||||
shape_check_output_types)
|
||||
shape_check_func = func_dialect.FuncOp(
|
||||
"main", shape_check_ftype,
|
||||
ip=ir.InsertionPoint.at_block_begin(module.body))
|
||||
shape_check_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
symbol_table.insert(shape_check_func)
|
||||
entry_block = shape_check_func.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
module_context = mlir.ModuleContext(
|
||||
"cpu", "cpu", sharding_impls.ShardingContext([]),
|
||||
source_info_util.new_name_stack(),
|
||||
[], itertools.count(1), [], module=module, context=context)
|
||||
ctx = mlir.LoweringRuleContext(module_context=module_context,
|
||||
primitive=None, avals_in=args_avals_flat,
|
||||
avals_out=None, tokens_in=mlir.TokenSet(),
|
||||
tokens_out=None)
|
||||
acc_shape_check_messages: List[str] = []
|
||||
values = mlir.lower_fun(
|
||||
functools.partial(shape_poly.compute_shape_check_from_arg_shapes,
|
||||
args_avals_flat, args_kwargs_tree=args_kwargs_tree,
|
||||
acc_shape_check_messages=acc_shape_check_messages),
|
||||
multiple_results=True)(ctx, *shape_check_func.arguments)
|
||||
func_dialect.ReturnOp(util.flatten(values))
|
||||
|
||||
return (module, acc_shape_check_messages) if acc_shape_check_messages else None
|
||||
|
||||
|
||||
def _check_lowering(lowering) -> None:
|
||||
@ -707,9 +803,9 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
# Expand in_shardings to all in_avals even not kept ones.
|
||||
all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals)
|
||||
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
|
||||
primal.in_shardings):
|
||||
primal.in_shardings): # type: ignore
|
||||
all_in_shardings[idx] = in_s # type: ignore
|
||||
all_shardings = all_in_shardings + list(primal.out_shardings)
|
||||
all_shardings = all_in_shardings + list(primal.out_shardings) # type: ignore
|
||||
# Cannot mix unspecified and specified shardings. Make the unspecified
|
||||
# ones replicated.
|
||||
specified_shardings = [
|
||||
@ -817,16 +913,28 @@ def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
|
||||
if (not core.is_constant_dim(actual_aval.shape[dim_idx]) or
|
||||
aval_d != actual_aval.shape[dim_idx]):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for {pp_arg_dim(dim_idx)} (expected constant): "
|
||||
f"Shape mismatch for {pp_arg_dim(dim_idx)} "
|
||||
"(expected same constant): "
|
||||
f"expected {exp_aval.shape} and called with {actual_aval.shape}")
|
||||
|
||||
# Must express the exported_dim_vars in terms of the shapes in in_avals.
|
||||
solution, shape_constraints, known_dim_vars = shape_poly.solve_dim_vars(
|
||||
solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars(
|
||||
exported.in_avals, args_kwargs_tree=exported.in_tree)
|
||||
known_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
||||
for (vname, arg_idx, dim_idx) in known_dim_vars}
|
||||
shape_constraints.check(known_env)
|
||||
exported_dim_values = [solution[var].evaluate(known_env)
|
||||
synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx]
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
  # We discharge all the constraints statically. This results in much simpler
  # composability (because we do not have to worry about the constraints of the
  # Exported called recursively; we only need to worry about entry-point
  # constraints). This also makes sense from a composability point of view,
  # because we get the same errors whether we invoke the exported module or
  # trace the exported function. Consider, for example, an exported module with
  # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
  # `f32[c, d]` it is better to fail because `c == d` is inconclusive than to
  # succeed and add a compile-time check that `c == d`. In the latter case,
  # it would be ambiguous whether we should continue tracing with a result
  # of type `f32[c]` or `f32[d]`.
  shape_constraints.check_statically(synthetic_env)
  exported_dim_values = [solution[var].evaluate(synthetic_env)
                         for var in exported_dim_vars]
|
||||
return tuple(
|
||||
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
|
||||
@ -881,8 +989,9 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
kept_args)
|
||||
# The ctx.avals_out already contain the abstract values refined by
|
||||
# _call_exported_abstract_eval.
|
||||
return tuple(convert_shape(out, out_aval, refined_out_aval)
|
||||
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
|
||||
return tuple(
|
||||
convert_shape(out, out_aval, refined_out_aval)
|
||||
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
|
||||
|
||||
|
||||
for _p in ("cpu", "tpu", "cuda", "rocm"):
|
||||
|
@ -51,6 +51,7 @@ from jax.interpreters import xla
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import tree_util
|
||||
@ -1164,20 +1165,61 @@ class ShapeConstraint:
|
||||
comp: Comparator
|
||||
left: DimSize
|
||||
right: DimSize
|
||||
# make_err_msg is invoked with (left_int, right_int) if the constraint fails.
|
||||
make_err_msg: Callable[[int, int], str]
|
||||
# make_err_msg lazily produces an error message that may refer to the
|
||||
# two operands being compared.
|
||||
make_err_msg: Callable[[Any, Any], str]
|
||||
|
||||
def check(self, shapeenv: DimVarEnv) -> None:
|
||||
"""Evaluates a constraint statically and raises an error if fails."""
|
||||
def eval_operand(o: DimSize) -> Union[int, jax.Array]:
|
||||
if core.is_constant_dim(o): return op.index(o)
|
||||
return o.evaluate(shapeenv) # type: ignore
|
||||
def _eval_operand(self, o: DimSize, shapeenv: DimVarEnv,
|
||||
) -> Union[int, jax.Array]:
|
||||
if core.is_constant_dim(o):
|
||||
res = op.index(o)
|
||||
else:
|
||||
res = o.evaluate(shapeenv) # type: ignore
|
||||
# Ensure we have np.int32
|
||||
res_dtype = dtypes.dtype(res)
|
||||
if res_dtype == np.int32 or is_poly_dim(res):
|
||||
return res
|
||||
if type(res) in dtypes.python_scalar_dtypes:
|
||||
return dtypes.coerce_to_array(res, np.int32) # type: ignore
|
||||
else:
|
||||
return res.astype(np.int32)
|
||||
|
||||
def _err_msg(self, left, right):
|
||||
return self.make_err_msg(str(left), str(right))
|
||||
|
||||
def check_statically(self, shapeenv: DimVarEnv) -> None:
|
||||
"""Evaluates a constraint statically.
|
||||
|
||||
The `shapeenv` maps variables to _DimExpr. If the static checking
|
||||
of the constraint fails, raise ValueError.
|
||||
"""
|
||||
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
|
||||
self.right]]
|
||||
try:
|
||||
left1, right1 = eval_operand(self.left), eval_operand(self.right)
|
||||
except KeyError:
|
||||
return None
|
||||
if self.comp == ShapeConstraint.Comparator.EQ:
|
||||
ok = (left == right)
|
||||
elif self.comp == ShapeConstraint.Comparator.GEQ:
|
||||
ok = (left >= right)
|
||||
else: assert False
|
||||
except InconclusiveDimensionOperation as e:
|
||||
raise ValueError(self.make_err_msg(left, right)) from e
|
||||
if not ok:
|
||||
raise ValueError(self.make_err_msg(left, right))
|
||||
|
||||
left_int, right_int = _is_known_constant(left1), _is_known_constant(right1)
|
||||
def compute(self,
|
||||
shapeenv: DimVarEnv
|
||||
) -> Optional[Tuple[jax.Array, jax.Array, jax.Array]]:
|
||||
"""Computes if the constraint is satisfied.
|
||||
|
||||
If the constraint can be resolved statically returns None
|
||||
or raises ValueError otherwise. If the constraint cannot be
|
||||
resolved statically, returns a triple with a boolean encoding if the
|
||||
constraint is satisfied and the int32 operands of binary constraints.
|
||||
"""
|
||||
left, right = [self._eval_operand(o, shapeenv) for o in [self.left,
|
||||
self.right]]
|
||||
# Try to evaluate the constraint statically.
|
||||
left_int, right_int = _is_known_constant(left), _is_known_constant(right)
|
||||
if left_int is not None and right_int is not None:
|
||||
if self.comp == ShapeConstraint.Comparator.EQ:
|
||||
if not (left_int == right_int):
|
||||
@ -1186,31 +1228,69 @@ class ShapeConstraint:
|
||||
if not (left_int >= right_int):
|
||||
raise ValueError(self.make_err_msg(left_int, right_int))
|
||||
else: assert False
|
||||
else:
|
||||
return None # TODO: evaluate constraint dynamically
|
||||
return None
|
||||
|
||||
if self.comp == ShapeConstraint.Comparator.EQ:
|
||||
is_ok = lax.eq(left, right)
|
||||
elif self.comp == ShapeConstraint.Comparator.GEQ:
|
||||
is_ok = lax.ge(left, right)
|
||||
else: assert False
|
||||
return is_ok, left, right # type: ignore
|
||||
|
||||
def __str__(self):
|
||||
return (f"{self.left} {'==' if self.comp == ShapeConstraint.Comparator.EQ else '>='} {self.right}"
|
||||
f" ({self.make_err_msg(self.left, self.right)})")
|
||||
f" ({self.make_err_msg('%1', '%2')})")
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
class ShapeConstraints:
|
||||
def __init__(self):
|
||||
self.constraints: Set[ShapeConstraint] = set() # map DimConstraint to an integer >= 0
|
||||
|
||||
self.constraints: List[ShapeConstraint] = []
|
||||
|
||||
def add_constraint(self,
|
||||
comp: ShapeConstraint.Comparator,
|
||||
left: DimSize, right: DimSize,
|
||||
make_err_msg: Callable[[int, int], str]):
|
||||
make_err_msg: Callable[[Any, Any], str]):
|
||||
# Try to evaluate it statically
|
||||
c = ShapeConstraint(comp, left, right, make_err_msg)
|
||||
self.constraints.add(c)
|
||||
self.constraints.append(c)
|
||||
|
||||
def check(self, shapeenv: DimVarEnv) -> None:
|
||||
def check_statically(self, shapeenv: DimVarEnv) -> None:
|
||||
"""Evaluates all the constraints statically.
|
||||
|
||||
The `shapeenv` maps variables to _DimExpr. If the static checking
|
||||
of any constraint fails, raise ValueError.
|
||||
"""
|
||||
for constraint in self.constraints:
|
||||
constraint.check(shapeenv)
|
||||
constraint.check_statically(shapeenv)
|
||||
|
||||
def compute(self, shapeenv: DimVarEnv) -> Tuple[jax.Array, jax.Array, jax.Array, Sequence[str]]:
|
||||
"""Computes the error code for the set of constraints.
|
||||
|
||||
The error code is -1 if all constraints are satisfied, or an index into
|
||||
the returned error messages.
|
||||
See Exported.shape_check_module docstring.
|
||||
"""
|
||||
# We want to report the errors in the same order as `check_statically`.
|
||||
# So, we process them in order, in case some fail statically, but we
|
||||
# accumulate the errors in reverse order in the computation, because the
|
||||
# code-generation strategy will report the last error that failed.
|
||||
acc: List[Tuple[jax.Array, jax.Array, jax.Array, str]] = []
|
||||
for constraint in self.constraints:
|
||||
check_res = constraint.compute(shapeenv)
|
||||
if check_res is not None:
|
||||
acc.append((*check_res, constraint.make_err_msg("%1", "%2"))) # type: ignore
|
||||
|
||||
shape_check_messages: List[str] = []
|
||||
shape_check_code: jax.Array = np.int32(-1) # type: ignore
|
||||
shape_check_op1 = shape_check_op2 = shape_check_code
|
||||
for (is_ok, op1, op2, msg) in reversed(acc):
|
||||
shape_check_code = lax.select(is_ok, shape_check_code,
|
||||
np.int32(len(shape_check_messages)))
|
||||
shape_check_op1 = lax.select(is_ok, shape_check_op1, op1)
|
||||
shape_check_op2 = lax.select(is_ok, shape_check_op2, op2)
|
||||
shape_check_messages.append(msg)
|
||||
return (shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages)
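
A small illustration of the accumulation order in `compute` above (assumed, illustrative constraint names):

  # Suppose constraints C0 and C1 both fail at runtime.
  # reversed(acc) processes C1 first, then C0:
  #   after C1: shape_check_code = 0, shape_check_messages = [msg_C1]
  #   after C0: shape_check_code = 1, shape_check_messages = [msg_C1, msg_C0]
  # The final code (1) indexes msg_C0, i.e. the first failing constraint in the
  # original order, matching the order in which check_statically reports errors.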
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -1256,50 +1336,52 @@ def solve_dim_vars(
|
||||
|
||||
args_avals = [ShapedArray((3, a, a + b), f32)]
|
||||
|
||||
we introduce fresh "known" dimension variables to represent the actual dimension
|
||||
size of actual arguments for each non-constant dimension. Each known variable
|
||||
has a name, an arg_idx, and a dim_idx, e.g.:
|
||||
we introduce fresh "synthetic" dimension variables to represent the actual
|
||||
dimension size of actual arguments for each non-constant dimension.
|
||||
Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.:
|
||||
|
||||
known_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
|
||||
synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]
|
||||
|
||||
and then we express the solution for the unknown dimension variables {a, b}
|
||||
as symbolic expressions in terms of the known variables:
|
||||
as symbolic expressions in terms of the synthetic variables:
|
||||
|
||||
dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])
|
||||
|
||||
Not all equations are solvable. For now, we solve first the linear uni-variate
|
||||
equations, then the solved variables are used to simplify the remaining
|
||||
equations to linear uni-variate equations, and the process continues
|
||||
until all dimension variables are solved.
|
||||
Not all equations are solvable. For now, we solve first the linear
|
||||
uni-variate equations, then the solved variables are used to simplify the
|
||||
remaining equations to linear uni-variate equations, and the process
|
||||
continues until all dimension variables are solved.
|
||||
|
||||
Args:
|
||||
args_avals: the abstract values of the `args`, with shapes that may
|
||||
include unknown dimension variables.
|
||||
args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)` from
|
||||
which the flat sequence `args_avals` is extracted. Used for describing
|
||||
args and kwargs in known variable names and in error messages.
|
||||
args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)`
|
||||
from which the flat sequence `args_avals` is extracted. Used for
|
||||
describing args and kwargs in synthetic variable names and in
|
||||
error messages.
|
||||
|
||||
Returns: a 3-tuple with: (a) the solution for the unknown dimension variables
|
||||
(b) a list of constraints that must be satisfied for the solution to be a
|
||||
valid one, and (c) and the list of known variables that may appear in
|
||||
valid one, and (c) and the list of synthetic variables that may appear in
|
||||
the solution and the constraints.
|
||||
|
||||
Raises ValueError if it cannot solve some dimension variable.
|
||||
"""
|
||||
dim_equations: List[_DimEquation] = []
|
||||
known_dimension_vars: List[Tuple[str, int, int]] = []
|
||||
synth_dimension_vars: List[Tuple[str, int, int]] = []
|
||||
for arg_idx, aval in enumerate(args_avals):
|
||||
for dim_idx, aval_d in enumerate(aval.shape):
|
||||
if is_poly_dim(aval_d):
|
||||
known_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
|
||||
synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
|
||||
arg_idx, dim_idx)
|
||||
known_dimension_vars.append((known_dim_var, arg_idx, dim_idx))
|
||||
synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx))
|
||||
dim_equations.append(
|
||||
_DimEquation(dim_expr=_ensure_poly(aval_d, "solve_dim_vars"),
|
||||
dim_value=_DimExpr.from_var(known_dim_var)))
|
||||
dim_value=_DimExpr.from_var(synth_dim_var)))
|
||||
|
||||
solution, shape_constraints = _solve_dim_equations(dim_equations)
|
||||
return solution, shape_constraints, known_dimension_vars
|
||||
return solution, shape_constraints, synth_dimension_vars
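
A worked instance of the docstring example above (values are illustrative):

  # args_avals = [ShapedArray((3, a, a + b), f32)], actual argument shape (3, 4, 12).
  synthetic_env = {"args[0].shape[1]": 4, "args[0].shape[2]": 12}
  # solution: a = args[0].shape[1] = 4, b = args[0].shape[2] - args[0].shape[1] = 8
  # The constraints a >= 1 and b >= 1 hold, so check_statically passes.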
|
||||
|
||||
|
||||
def compute_dim_vars_from_arg_shapes(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
@ -1315,17 +1397,47 @@ def compute_dim_vars_from_arg_shapes(
|
||||
`all_dim_vars(args_avals)`.
|
||||
"""
|
||||
dim_vars = all_dim_vars(args_avals)
|
||||
solution, shape_constraints, known_dim_vars = solve_dim_vars(
|
||||
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
|
||||
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
|
||||
|
||||
# Replace the synthetic vars with the dynamic shape of the actual arg
|
||||
known_env = {vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in known_dim_vars}
|
||||
dim_values = [solution[var].evaluate(known_env) for var in dim_vars]
|
||||
shape_constraints.check(known_env)
|
||||
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
|
||||
dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
dim_values = [solution[var].evaluate(synthetic_env) for var in dim_vars]
|
||||
return tuple(dim_values)
|
||||
|
||||
|
||||
def compute_shape_check_from_arg_shapes(
|
||||
args_avals: Sequence[core.AbstractValue],
|
||||
*actual_args: jax.Array,
|
||||
args_kwargs_tree: tree_util.PyTreeDef,
|
||||
acc_shape_check_messages: List[str]) -> Sequence[jax.Array]:
|
||||
"""Computes the shape check code from the actual arguments.
|
||||
|
||||
  `acc_shape_check_messages` is an initially empty list where we append the
  shape checking messages. Each of these corresponds to a constraint.
  See the `Exported.shape_check_module` docstring.
|
||||
|
||||
Returns: a triple of JAX arrays: the shape check code and the values of the
|
||||
two operands for the failed constraint.
|
||||
|
||||
Raises ValueError if a constraint is invalidated statically.
|
||||
"""
|
||||
assert not acc_shape_check_messages
|
||||
solution, shape_constraints, synth_dim_vars = solve_dim_vars(
|
||||
tuple(args_avals), args_kwargs_tree=args_kwargs_tree)
|
||||
|
||||
# Replace the synthetic vars with the dynamic shape of the actual arg
|
||||
synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx],
|
||||
dimension=dim_idx)
|
||||
for (vname, arg_idx, dim_idx) in synth_dim_vars}
|
||||
shape_check_code, shape_check_op1, shape_check_op2, shape_check_messages = (
|
||||
shape_constraints.compute(synthetic_env))
|
||||
acc_shape_check_messages.extend(shape_check_messages)
|
||||
return (shape_check_code, shape_check_op1, shape_check_op2)
|
||||
|
||||
|
||||
def _solve_dim_equations(
|
||||
eqns: List[_DimEquation]
|
||||
) -> Tuple[DimVarEnv, ShapeConstraints]:
|
||||
@ -1375,16 +1487,16 @@ def _solve_dim_equations(
|
||||
var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.EQ, var_remainder, 0,
|
||||
make_err_msg=lambda rem_int, _: (
|
||||
make_err_msg=lambda left, _: (
|
||||
f"Dimension variable '{var}' must have integer value >= 1. "
|
||||
f"Non-zero remainder {rem_int} for factor {factor_var} when solving "
|
||||
f"Non-zero remainder {left} for factor {factor_var} when solving "
|
||||
f"{eqn}.{_shapeenv_to_str()}"))
|
||||
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.GEQ, var_value, 1,
|
||||
make_err_msg=lambda var_int, _: (
|
||||
make_err_msg=lambda left, _: (
|
||||
f"Dimension variable '{var}' must have integer value >= 1. "
|
||||
f"Found {var_int} when "
|
||||
f"Found {left} when "
|
||||
f"solving {eqn}.{_shapeenv_to_str()}"))
|
||||
|
||||
if not isinstance(var_value, _DimExpr):
|
||||
@ -1396,8 +1508,8 @@ def _solve_dim_equations(
|
||||
shape_constraints.add_constraint(
|
||||
ShapeConstraint.Comparator.EQ, eqn.dim_value,
|
||||
eqn.dim_expr.evaluate(shapeenv),
|
||||
make_err_msg=lambda val1, val2: (
|
||||
f"Found inconsistency {val1} != {val2} when solving {eqn}.{_shapeenv_to_str()}"))
|
||||
make_err_msg=lambda left, right: (
|
||||
f"Found inconsistency {left} != {right} when solving {eqn}.{_shapeenv_to_str()}"))
|
||||
return True
|
||||
|
||||
while True:
|
||||
|
@ -15,31 +15,150 @@ import contextlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import List
|
||||
from typing import List, Optional, Sequence
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.config import config
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
try:
|
||||
from jax.experimental.jax2tf import jax2tf # TODO: temporary
|
||||
except ImportError:
|
||||
jax2tf = None # type: ignore
|
||||
|
||||
from jax.lib import xla_client as xc
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
if xc.mlir_api_version >= 50:
|
||||
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
|
||||
skip_shape_check: bool = False):
|
||||
assert all(core.is_constant_shape(a.shape) for a in args)
|
||||
return jax_export.call_exported(exp)(*args)
|
||||
else:
|
||||
def _temp_call_exported(exp: jax_export.Exported, *args: jax.Array,
|
||||
skip_shape_check: bool = False):
|
||||
"""Temporary runner for an Exported.
|
||||
|
||||
Normally we would use jax_export.call_exported, but if the exported has
|
||||
shape polymorphism and we are using jaxlib before 0.4.12 we use
|
||||
Once we upgrade the jaxlib we can replace all uses of this function with
|
||||
jax_export.call_exported.
|
||||
"""
|
||||
assert all(core.is_constant_shape(a.shape) for a in args)
|
||||
if not exp.module_uses_dim_vars:
|
||||
return jax_export.call_exported(exp)(*args)
|
||||
else:
|
||||
# We only get here in external tests, because internal ones use newest jaxlib
|
||||
from jax.experimental.jax2tf import jax2tf # TODO: temporary
|
||||
# call_exported does the shape checking, we must do it manually for
|
||||
# XlaCallModule.
|
||||
if not skip_shape_check:
|
||||
shape_check = exp.shape_check_module()
|
||||
if shape_check is not None:
|
||||
err_msg = _run_shape_check_module(exp, shape_check[0], shape_check[1], args)
|
||||
if err_msg is not None:
|
||||
raise ValueError(err_msg)
|
||||
|
||||
numpy_results = map(lambda res_tf: res_tf.numpy(),
|
||||
jax2tf._run_exported_as_tf(args, exp))
|
||||
return exp.out_tree.unflatten(numpy_results)
|
||||
|
||||
def _run_shape_check_module(primal_exported: jax_export.Exported,
|
||||
shape_check_module_serialized: bytes,
|
||||
shape_check_messages: Sequence[str],
|
||||
args: Sequence[jax.Array]
|
||||
) -> Optional[str]:
|
||||
"""Helper to run a shape checking module.
|
||||
|
||||
We only need to do this in tests, because otherwise this will be done
|
||||
implicitly when we call XlaCallModule.
|
||||
|
||||
Returns: the error message, or None if no error.
|
||||
"""
|
||||
args = tuple(args)
|
||||
# We cannot just make an Exported and run it, because call_exported will
|
||||
# do the shape checks statically. So, we wrap the shape check module
|
||||
# with static shape arguments
|
||||
|
||||
static_in_avals = tuple(core.get_aval(a) for a in args)
|
||||
context = mlir.make_ir_context()
|
||||
with context, ir.Location.unknown(context):
|
||||
wrapped_module = ir.Module.parse(
|
||||
xla_extension.mlir.deserialize_portable_artifact(
|
||||
shape_check_module_serialized))
|
||||
symbol_table = ir.SymbolTable(wrapped_module.operation)
|
||||
orig_main = symbol_table["main"]
|
||||
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
||||
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
|
||||
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
|
||||
# Use static shapes
|
||||
new_main_input_types = [mlir.aval_to_ir_type(a) for a in static_in_avals]
|
||||
orig_output_types = orig_main.type.results
|
||||
new_main_ftype = ir.FunctionType.get(
|
||||
new_main_input_types, orig_output_types
|
||||
)
|
||||
new_main_op = func_dialect.FuncOp(
|
||||
"main",
|
||||
new_main_ftype,
|
||||
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
|
||||
)
|
||||
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
|
||||
symbol_table.insert(new_main_op)
|
||||
entry_block = new_main_op.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
orig_main_args: List[ir.Value] = []
|
||||
for new_arg, orig_arg_type in zip(
|
||||
new_main_op.arguments, orig_main.type.inputs
|
||||
):
|
||||
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
|
||||
call = func_dialect.CallOp(
|
||||
orig_output_types,
|
||||
ir.FlatSymbolRefAttr.get(orig_main_name),
|
||||
orig_main_args,
|
||||
)
|
||||
func_dialect.ReturnOp(call.results)
|
||||
symbol_table.set_symbol_name(new_main_op, "main")
|
||||
|
||||
wrapped_module_serialized, version = jax_export._serialize_module(wrapped_module)
|
||||
# Make an Exported and then run it
|
||||
out_avals = (core.ShapedArray((), dtype=np.int32),) * 3
|
||||
exp = jax_export.Exported(
|
||||
fun_name=f"shape_check_{primal_exported.fun_name}",
|
||||
in_tree=tree_util.tree_flatten((args, {}))[1],
|
||||
in_avals=static_in_avals,
|
||||
out_tree=tree_util.tree_flatten(out_avals)[1],
|
||||
out_avals=out_avals,
|
||||
in_shardings=None,
|
||||
out_shardings=None,
|
||||
lowering_platform=primal_exported.lowering_platform,
|
||||
disabled_checks=(),
|
||||
mlir_module_serialized=wrapped_module_serialized,
|
||||
xla_call_module_version=version,
|
||||
module_kept_var_idx=tuple(sorted(range(len(static_in_avals)))),
|
||||
module_uses_dim_vars=True,
|
||||
_get_vjp=lambda _: None) # type: ignore
|
||||
|
||||
code, op1, op2 = _temp_call_exported(exp, *args,
|
||||
skip_shape_check=True)
|
||||
if code == -1:
|
||||
return None
|
||||
else:
|
||||
return shape_check_messages[code].replace("%1", str(int(op1))).replace(
|
||||
"%2", str(int(op2)))
|
||||
|
||||
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic_export_only(self):
|
||||
@ -171,9 +290,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"Dtype mismatch for args\[0\]"):
|
||||
jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=p, platform=p)
|
||||
for p in ("cpu", "cuda", "rocm", "tpu"))
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["platform"],
|
||||
kwargs=[dict(platform=p)
|
||||
for p in ("cpu", "cuda", "rocm", "tpu")])
|
||||
def test_error_wrong_platform(self, platform):
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
|
||||
@ -239,6 +359,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
|
||||
self.assertAllClose(jax_vjp, exp_vjp)
|
||||
|
||||
|
||||
def test_roundtrip(self):
|
||||
def f1(x):
|
||||
return jnp.sin(x)
|
||||
@ -253,150 +374,199 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
|
||||
jax_export.call_exported(exp_f2)(a))
|
||||
|
||||
# A function is exported with f32[poly_spec] and is called with different arg
|
||||
# shapes. We use jax_export.call_exported and we also run the shape check
|
||||
# module.
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
|
||||
#one_containing="",
|
||||
kwargs=[
|
||||
dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
|
||||
dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
|
||||
# The shape check module does not test constant dimensions
|
||||
expect_error_run=re.escape(
|
||||
r"Shape mismatch for args[0].shape[2] (expected same constant)")),
|
||||
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
|
||||
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
|
||||
dict(poly_spec="3,4,a+1", arg_shape=(3, 4, 1),
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'a' must have integer "
|
||||
r"value >= 1. Found 0 when solving "
|
||||
r"a + 1 == args[0].shape[2].")),
|
||||
dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 13),
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 1 for factor 6 when solving "
|
||||
r"6*a == args[0].shape[2]")),
|
||||
dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 13),
|
||||
expect_error=re.escape(
|
||||
r"Found inconsistency 13 != 12 when solving "
|
||||
r"a + 8 == args[0].shape[2]")),
|
||||
])
|
||||
def test_poly_shape_checks(
|
||||
self, poly_spec="3,a,a+8",
|
||||
arg_shape=(3, 4, 12), arg_dtype=np.float32,
|
||||
expect_error=None, # If given, applies for expect_error_run and expect_error_shape_check
|
||||
expect_error_run=None, # Error from running the exported module
|
||||
expect_error_shape_check=None): # Error from running the shape check module
|
||||
if expect_error is not None:
|
||||
self.assertIsNone(expect_error_run, None)
|
||||
self.assertIsNone(expect_error_shape_check, None)
|
||||
expect_error_run = expect_error_shape_check = expect_error
|
||||
def f(x): # x: f32[poly_spec]
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
exp_f = jax_export.export(f)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
|
||||
self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12")
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
|
||||
res = _temp_call_exported(exp_f, arg)
|
||||
|
||||
if not expect_error_run:
|
||||
self.assertAllClose(res, f(arg))
|
||||
|
||||
# Test the shape_check_module
|
||||
shape_check = exp_f.shape_check_module()
|
||||
# We have shape_check only if the exported has polymorphic inputs shapes.
|
||||
if all(core.is_constant_shape(a.shape) for a in exp_f.in_avals):
|
||||
self.assertIsNone(shape_check)
|
||||
self.assertIsNone(expect_error_shape_check)
|
||||
else:
|
||||
self.assertIsNotNone(shape_check)
|
||||
shape_check_module, shape_check_messages = shape_check
|
||||
err_msg = _run_shape_check_module(exp_f,
|
||||
shape_check_module, shape_check_messages, (arg,))
|
||||
|
||||
if expect_error_shape_check is None:
|
||||
self.assertIsNone(err_msg)
|
||||
else:
|
||||
self.assertRegex(err_msg, expect_error_shape_check)
|
||||
|
||||
|
||||
# An inner function is exported with polymorphic shapes inner_poly_spec, and
|
||||
# is called from an outer function, which is exported with outer_poly_spec.
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"inner={d['inner_poly_spec']}_outer={d['outer_poly_spec']}", # type: ignore
|
||||
**d) # type: ignore
|
||||
for d in (
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'b' must have integer value >= 1. "
|
||||
r"Found 0 when solving a + b == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error=r"Shape mismatch for args\[0\].shape\[0\] \(expected constant\)"),
|
||||
# dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c+4,12"), # TODO: there should be an error
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
|
||||
expect_error=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
|
||||
# dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c"), # TODO: there should be an error 5*a != c == 12
|
||||
# dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a"), # TODO: there should be an error 12 != 4
|
||||
dict(inner_poly_spec="3,a", inner_x_shape=(3, 4), outer_poly_spec="3,a,a",
|
||||
expect_error=r"Rank mismatch for args\[0\]"),
|
||||
dict(inner_poly_spec="3,a,a+b", inner_x_dtype=np.int32, outer_poly_spec="3,c,d",
|
||||
expect_error=r"Dtype mismatch for args\[0\]"),
|
||||
))
|
||||
def test_poly(self, inner_poly_spec="3,a,a", inner_x_shape=(3, 4, 6),
|
||||
inner_x_dtype=np.float32,
|
||||
outer_poly_spec="3,a,a", outer_x_shape=(3, 4, 12),
|
||||
expect_error=None):
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw:f"inner={kw['inner_poly_spec']}_outer={kw['outer_poly_spec']}", # type: ignore
|
||||
#one_containing="",
|
||||
# By default arg_shape = (3, 4, 12) for both the outer function and the inner
|
||||
# The inner function is exported for f32.
|
||||
kwargs=[
|
||||
# Both inner and outer are static shapes
|
||||
dict(inner_poly_spec="3,4,12", outer_poly_spec="3,4,12"),
|
||||
# Inner has poly shapes but outer has static shapes
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,12",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Found inconsistency 12 != 4 when solving a == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder 2 for factor 5 when solving 5*a == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,4,12+a", outer_poly_spec="3,4,12",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Found 0 when solving a + 12 == args[0].shape[2]")),
|
||||
# Both inner and outer have poly shapes.
|
||||
dict(inner_poly_spec="3,a,b", outer_poly_spec="3,4,c"),
|
||||
dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,6*c"),
|
||||
dict(inner_poly_spec="3,a,a+8", outer_poly_spec="3,c+2,c+10"),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Dimension variable 'b' must have integer value >= 1. "
|
||||
r"Found c + -4 when solving a + b == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Found inconsistency c != 4 when solving a == args[0].shape[2]"
|
||||
)),
|
||||
dict(inner_poly_spec="3,a,a", arg_shape=(3, 4),
|
||||
outer_poly_spec="3,c",
|
||||
expect_error_outer_exp=r"Rank mismatch for args\[0\]"),
|
||||
dict(inner_poly_spec="3,a,a+b", arg_dtype=np.int32,
|
||||
outer_poly_spec="3,c,d",
|
||||
expect_error_outer_exp=r"Dtype mismatch for args\[0\]"),
|
||||
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Dimension variable 'a' must have integer value >= 1. "
|
||||
r"Non-zero remainder mod(c, 5) for factor 5 when solving 5*a == args[0].shape[2]"
|
||||
)),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Dimension variable 'b' must have integer value >= 1. "
|
||||
r"Found 0 when solving a + b == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error_outer_exp=re.escape(
|
||||
r"Shape mismatch for args[0].shape[0] (expected same constant)")),
|
||||
dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,25*c",
|
||||
expect_error_run=re.escape(
|
||||
r"Dimension variable 'c' must have integer value >= 1. "
|
||||
r"Non-zero remainder 12 for factor 25 when solving 25*c == args[0].shape[2]")),
|
||||
dict(inner_poly_spec="3,a,b", outer_poly_spec="3,c+4,12",
|
||||
expect_error_run=re.escape(
|
||||
r"Dimension variable 'c' must have integer value >= 1. "
|
||||
r"Found 0 when solving c + 4 == args[0].shape[1]")),
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
|
||||
expect_error_run=re.escape(
|
||||
r"Found inconsistency 12 != 4 when solving "
|
||||
r"a == args[0].shape[2]")),
|
||||
])
|
||||
  def test_poly_shape_checks_nested(
      self, inner_poly_spec="3,4,5*a",
      arg_shape=(3, 4, 12), arg_dtype=np.float32,
      outer_poly_spec="3,4,25*c",
      expect_error_outer_exp=None,
      expect_error_run=None):
    # Polymorphic export called with static or polymorphic shapes
    def inner(x):  # x: inner_poly_spec
      return jnp.reshape(x, (-1, x.shape[1]))

    inner_x = np.arange(np.prod(inner_x_shape),
                        dtype=inner_x_dtype).reshape(inner_x_shape)  # inner_x : f32[3,4,6]
    arg = np.arange(np.prod(arg_shape),
                    dtype=arg_dtype).reshape(arg_shape)  # x : f32[3,4,12]
    inner_exp = jax_export.export(inner)(
        jax_export.poly_spec(inner_x.shape, inner_x.dtype, inner_poly_spec))
        jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))

    self.assertEqual(inner_exp.module_uses_dim_vars,
                     (inner_poly_spec != "3,4,12"))
    outer_x = np.arange(np.prod(outer_x_shape),
                        dtype=np.float32).reshape(outer_x_shape)  # outer_x : f32[3,4,12]
    def outer(x):  # x: outer_poly_spec
      # Use an addition to test that the shapes are refined properly for the
      # result of the call_exported.
      return jax_export.call_exported(inner_exp)(x) + inner(x)

    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(ValueError, expect_error))
      if expect_error_outer_exp is not None:
        stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))

      # Call it after exporting again, with polymorphic shapes
      outer_exp = jax_export.export(outer)(
          jax_export.poly_spec(outer_x.shape, outer_x.dtype, outer_poly_spec))
      self.assertEqual(outer_exp.module_uses_dim_vars,
                       (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
      # TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
      if not outer_exp.module_uses_dim_vars or xc.mlir_api_version >= 50:
        res = jax_export.call_exported(outer_exp)(outer_x)
        self.assertAllClose(2. * inner(outer_x), res)
      else:
        # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
        # until we create the python bindings to invoke shape refinement.
        if jax2tf is not None:
          res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
          self.assertAllClose(2. * inner(outer_x), res)
          jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))

  def test_call_poly(self):
    a_shape = (3, 4)
    a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
      if expect_error_outer_exp is not None:
        return

    def f_inner(x):  # x: f32[w, h]
      return jnp.reshape(x, (-1,))
      self.assertEqual(outer_exp.module_uses_dim_vars,
                       (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))
      shape_check = outer_exp.shape_check_module()
      if all(core.is_constant_shape(a.shape) for a in outer_exp.in_avals):
        self.assertIsNone(shape_check)
      else:
        self.assertIsNotNone(shape_check)

    exp_inner = jax_export.export(f_inner)(
        jax_export.poly_spec(a.shape, a.dtype, "w, h")
    )
    with contextlib.ExitStack() as stack:
      if expect_error_run is not None:
        stack.push(self.assertRaisesRegex(Exception, expect_error_run))

    # There are dynamic shapes in the exported module
    self.assertIn("?x", exp_inner.mlir_module)
    self.assertIn("stablehlo.dynamic_reshape", exp_inner.mlir_module)
      res = _temp_call_exported(outer_exp, arg)

    # Add a wrapper "main" func with static shapes
    # TODO(necula): We will add this functionality to jax_export.
    from jax._src.interpreters import mlir
    from jax._src.lib.mlir import ir
    from jax._src.lib.mlir.dialects import hlo
    from jax._src.lib.mlir.dialects import func as func_dialect
    from jax.lib import xla_client as xc
    from jax._src.lib import xla_extension

    context = mlir.make_ir_context()
    with context, ir.Location.unknown(context):
      wrapped_module = ir.Module.parse(exp_inner.mlir_module)
      symbol_table = ir.SymbolTable(wrapped_module.operation)
      orig_main = symbol_table["main"]
      orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
      symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
      orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
      # Use static shapes
      new_main_input_types = [
          mlir.aval_to_ir_type(core.ShapedArray((3, 4), np.float32))
      ]
      orig_output_types = orig_main.type.results
      new_main_ftype = ir.FunctionType.get(
          new_main_input_types, orig_output_types
      )
      new_main_op = func_dialect.FuncOp(
          "main",
          new_main_ftype,
          ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
      )
      new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
      symbol_table.insert(new_main_op)
      entry_block = new_main_op.add_entry_block()
      with ir.InsertionPoint(entry_block):
        orig_main_args: List[ir.Value] = []
        for new_arg, orig_arg_type in zip(
            new_main_op.arguments, orig_main.type.inputs
        ):
          orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
        call = func_dialect.CallOp(
            orig_output_types,
            ir.FlatSymbolRefAttr.get(orig_main_name),
            orig_main_args,
        )
        func_dialect.ReturnOp(call.results)
      symbol_table.set_symbol_name(new_main_op, "main")

    # TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
    if xc.mlir_api_version >= 50:
      refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
          mlir.module_to_bytecode(wrapped_module)
      )
      context = mlir.make_ir_context()
      with context:
        refined_module = ir.Module.parse(refined_module_str)

      logging.info("Postprocessed module %s", str(refined_module))
      self.assertNotIn("?x", str(refined_module))
      self.assertNotIn("stablehlo.dynamic_reshape", str(refined_module))
      self.assertIn("stablehlo.reshape", str(refined_module))
    if expect_error_run is not None:
      return
    self.assertAllClose(2. * inner(arg), res)


if __name__ == "__main__":
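For orientation, a minimal sketch of the shape-check flow exercised by the tests above; the import path for jax_export is an assumption here (the tests reach it through the jax2tf package), while export, poly_spec, and shape_check_module are the calls asserted on above:

# Sketch only, assuming jax_export is importable from the jax2tf package.
import numpy as np
import jax.numpy as jnp
from jax.experimental.jax2tf import jax_export  # assumed import location

def f(x):  # x: f32[3, a, a+b]
  return jnp.reshape(x, (-1, x.shape[1]))

exp = jax_export.export(f)(
    jax_export.poly_spec((3, 4, 12), np.float32, "3,a,a+b"))
# Per the assertions in test_call_poly: shape_check_module() is None when all
# argument shapes are constant, and non-None when the module uses dimension
# variables that must be checked against the actual argument shapes.
print(exp.shape_check_module() is not None)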
@@ -841,7 +841,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
    f_tf: Callable[..., Any] = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
    self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y))

  def test_arg_avals(self):
  def test_arg_avals_non_native(self):
    """Test conversion of actual arguments to abstract values."""

    def check_avals(*, arg_shapes: Sequence[Sequence[Optional[int]]],
@@ -857,7 +857,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
      dim_vars = shape_poly.all_dim_vars(avals)
      dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
          partial(shape_poly.compute_dim_vars_from_arg_shapes,
                  avals, args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
                  avals,
                  args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
          args_tf, avals, "")
      if expected_avals is not None:
        self.assertEqual(expected_avals, avals)
@@ -958,55 +959,55 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
        eager_mode=True,
        expected_shapeenv=dict(a=2, b=3, c=4))

  def test_arg_avals_errors(self):
    """Test error reporting for shape polymorphism."""
    def conv_and_run(*, arg_shape: core.Shape,
                     polymorphic_shape: str):
      arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
      jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)

    with self.assertRaisesRegex(ValueError,
                                re.escape("polymorphic shape spec should be")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=5.)

    with self.assertRaisesRegex(ValueError,
                                re.escape("pytree structure error: different types")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=["a list"])

    with self.assertRaisesRegex(ValueError,
                                re.escape("pytree structure error: different types")):
      conv_and_run(arg_shape=(2,), polymorphic_shape=("a tuple",))

    # The following do not work yet with native serialization because
    # XlaCallModule does not yet do shape checking.
    if config.jax2tf_default_native_serialization:
      return

    # TODO(necula): enable even for native serialization
    with self.assertRaisesRegex(ValueError,
                                "Cannot solve for values of dimension variables {'b'}"):
      check_avals(
          arg_shapes=[(4, 36, 3)],
          polymorphic_shapes=[PS("b * b", "b * d * d", "d")])
      conv_and_run(arg_shape=(4, 36, 3), polymorphic_shape="b * b, b * d * d, d")

    # TODO(necula): enable even for native serialization
    with self.assertRaisesRegex(ValueError,
                                "Dimension variable 'b' must have integer value >= 1"):
      check_avals(
          arg_shapes=[(5, 36)],
          polymorphic_shapes=[PS("3 * b", ...)],
          eager_mode=True)
      conv_and_run(arg_shape=(5, 36), polymorphic_shape="3 * b, ...")

    # TODO(necula): enable even for native serialization
    with self.assertRaisesRegex(ValueError,
                                "Dimension variable 'b' must have integer value >= 1"):
      check_avals(
          arg_shapes=[(10, 3)],
          polymorphic_shapes=[PS("3 * b + 10", ...)],
          eager_mode=True)
      conv_and_run(arg_shape=(10, 3), polymorphic_shape="3 * b + 10, ...")

    # TODO(necula): enable even for native serialization
    with self.assertRaisesRegex(ValueError,
                                "Dimension variable 'b' must have integer value >= 1"):
      check_avals(
          arg_shapes=[(7, 3)],
          polymorphic_shapes=[PS("3 * b + 10", ...)],
          eager_mode=True)

    for invalid_syntax in [5.0, ["a list"], ("a tuple",), re.compile(".")]:
      with self.assertRaisesRegex(ValueError,
                                  re.escape("Invalid polymorphic shape element")):
        check_avals(
            arg_shapes=[(2,)], polymorphic_shapes=[PS([invalid_syntax])])
      conv_and_run(arg_shape=(7, 3), polymorphic_shape="3 * b + 10, ...")

    # TODO(necula): enable even for native serialization
    with self.assertRaisesRegex(
        ValueError,
        "Found inconsistency 3 != 2 when solving.*"):
      check_avals(
          arg_shapes=[(2, 3)],
          polymorphic_shapes=["(a, a)"],
          eager_mode=True)

    # Same error across multiple arguments
    with self.assertRaisesRegex(
        ValueError,
        "Found inconsistency 5 != 2 when solving.*"):
      check_avals(
          arg_shapes=[(2, 3), (5,)],
          polymorphic_shapes=["a, ...", "a"],
          eager_mode=True)
      conv_and_run(arg_shape=(2, 3), polymorphic_shape="(a, a)")

  def test_pytree(self):
    """Arguments and polymorphic_shapes are pytrees."""

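To make the kind of failure asserted by test_arg_avals_errors concrete, a small standalone sketch, assuming the public jax2tf entry point (jax.experimental.jax2tf.convert) and eager execution as in conv_and_run above; the expected message matches the regex used in the test:

# Sketch: an argument of shape (2, 3) cannot satisfy the polymorphic spec "(a, a)".
import numpy as np
from jax.experimental import jax2tf  # public entry point assumed here

x = np.zeros((2, 3), dtype=np.float32)
try:
  jax2tf.convert(lambda x: x, polymorphic_shapes=["(a, a)"])(x)
except ValueError as e:
  # Expected to match: "Found inconsistency 3 != 2 when solving ..."
  print(e)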