commit 1fd6e01289
jax authors  2023-03-30 07:37:47 -07:00

Merge pull request #15287 from gnecula:tf_dim_vars

PiperOrigin-RevId: 520633830

6 changed files with 156 additions and 166 deletions

View File

@@ -573,32 +573,33 @@ def sharded_aval(aval: core.ShapedArray,
return aval.update(tuple(sharded_shape))
-class DimExprEvaluator:
+class DimExprValueMlir:
+# TODO(necula): remove this, use regular JAX lowering
# A wrapper for an ir.Value that overloads + and * to be used for evaluating
-# symbolic dimensions.
+# symbolic dimensions resulting in ir.Value. See shape_poly.DimExprValue.
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
def __init__(self, value: ir.Value):
self.value = value
-def __add__(self, other: Union[np.int32, np.int64, DimExprEvaluator]):
-if not isinstance(other, DimExprEvaluator):
-other = DimExprEvaluator(ir_constant(other))
-return DimExprEvaluator(hlo.AddOp(self.value, other.value).result)
+def __add__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
+if not isinstance(other, DimExprValueMlir):
+other = DimExprValueMlir(ir_constant(other))
+return DimExprValueMlir(hlo.AddOp(self.value, other.value).result)
def __radd__(self, other: Union[np.int32, np.int64]):
-return DimExprEvaluator(ir_constant(other)).__add__(self)
+return DimExprValueMlir(ir_constant(other)).__add__(self)
-def __mul__(self, other: Union[np.int32, np.int64, DimExprEvaluator]):
-if not isinstance(other, DimExprEvaluator):
-other = DimExprEvaluator(ir_constant(other))
-return DimExprEvaluator(hlo.MulOp(self.value, other.value).result)
+def __mul__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
+if not isinstance(other, DimExprValueMlir):
+other = DimExprValueMlir(ir_constant(other))
+return DimExprValueMlir(hlo.MulOp(self.value, other.value).result)
def __rmul__(self, other: Union[np.int32, np.int64]):
-return DimExprEvaluator(ir_constant(other)).__mul__(self)
+return DimExprValueMlir(ir_constant(other)).__mul__(self)
-def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprEvaluator]):
-if not isinstance(divisor, DimExprEvaluator):
-divisor = DimExprEvaluator(ir_constant(divisor))
+def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprValueMlir]):
+if not isinstance(divisor, DimExprValueMlir):
+divisor = DimExprValueMlir(ir_constant(divisor))
# Quotient
raw_quotient = hlo.DivOp(self.value, divisor.value)
raw_remainder = hlo.RemOp(self.value, divisor.value)
@@ -613,11 +614,11 @@ class DimExprEvaluator:
raw_quotient)
# Remainder
remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
-return (DimExprEvaluator(quotient.result),
-DimExprEvaluator(remainder.result))
+return (DimExprValueMlir(quotient.result),
+DimExprValueMlir(remainder.result))
def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
-return DimExprEvaluator(ir_constant(dividend)).__divmod__(self)
+return DimExprValueMlir(ir_constant(dividend)).__divmod__(self)
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
@@ -625,7 +626,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
-dim_var_env = {dv_name : DimExprEvaluator(dv_val[0])
+dim_var_env = {dv_name: DimExprValueMlir(dv_val[0])
for dv_name, dv_val in zip(ctx.module_context.dim_vars, ctx.dim_var_values)}
def eval_dim(d: core.DimSize) -> Union[int, ir.Value]:
try:
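For intuition, DimExprValueMlir follows the DimExprValue protocol described in shape_poly.py: any wrapper that overloads __add__/__radd__/__mul__/__rmul__/__divmod__/__rdivmod__ can be used to evaluate a symbolic dimension expression. A minimal stand-alone sketch of the same pattern, with plain Python ints in place of ir.Value and a hypothetical class name IntDimValue:

class IntDimValue:
  """Minimal stand-in for DimExprValueMlir: wraps an int instead of an ir.Value."""
  def __init__(self, value: int):
    self.value = value
  def _lift(self, other):
    return other if isinstance(other, IntDimValue) else IntDimValue(int(other))
  def __add__(self, other):
    return IntDimValue(self.value + self._lift(other).value)
  def __radd__(self, other):
    return self._lift(other).__add__(self)
  def __mul__(self, other):
    return IntDimValue(self.value * self._lift(other).value)
  def __rmul__(self, other):
    return self._lift(other).__mul__(self)
  def __divmod__(self, divisor):
    q, r = divmod(self.value, self._lift(divisor).value)
    return IntDimValue(q), IntDimValue(r)
  def __rdivmod__(self, dividend):
    return divmod(self._lift(dividend), self)

# Evaluating the symbolic dimension 2*b + 3 with b == 5 yields 13.
b = IntDimValue(5)
assert (2 * b + 3).value == 13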

View File

@@ -399,8 +399,11 @@ def convert(fun_jax: Callable,
native_serialization_strict_checks)
return outs_tf, out_avals
else:
+def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
+return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(
-args_avals_flat)
+args_avals_flat, get_dimension_size)
dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values) # type: ignore
@@ -984,6 +987,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
# Returns a tuple of shape_poly.dim_as_value_dtype
+# Used only for non-native lowering
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
if dtype is not None:
@@ -993,7 +997,7 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
eval_shape_jax = shape_poly.get_shape_evaluator(dim_vars, shape)
-dim_aval = shape_poly.dim_as_value_abstract(1)
+dim_aval = shape_poly.dim_as_value_abstract()
shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
dim_values, [dim_aval] * len(dim_values), "") # type: ignore
# Keep only the non-constant dimensions
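For intuition, the get_dimension_size callback passed to shape_poly.prepare_dim_var_env only has to produce a DimExprValue for args[arg_idx].shape[dim_idx]; jax2tf binds dimension_size_p as above, while an eager stand-in (the name get_dimension_size_eager is hypothetical) could simply read the concrete shape:

import numpy as np

def get_dimension_size_eager(args, arg_idx: int, dim_idx: int) -> np.int32:
  # Same contract as the callback above, but over concrete arrays.
  return np.int32(np.shape(args[arg_idx])[dim_idx])

args = (np.zeros((3, 4)), np.zeros((7,)))
assert get_dimension_size_eager(args, 0, 1) == 4
assert get_dimension_size_eager(args, 1, 0) == 7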

View File

@@ -35,7 +35,7 @@ from jax._src.lib.mlir.dialects import stablehlo
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
+from jax.experimental.jax2tf import shape_poly
map = util.safe_map
zip = util.safe_zip
@@ -127,7 +127,7 @@ def serialize_native(fun_jax: Callable,
if not all(core.is_constant_shape(a.shape) for a in args_avals):
# All arguments are kept if we have dimension variables.
assert len(module_kept_var_idx) == len(args_avals)
-mlir_module = compute_dim_vars(mlir_module, args_avals)
+mlir_module = add_dim_arg_computation(mlir_module, args_avals)
xla_call_module_version = 4
mlir_str = mlir.module_to_bytecode(mlir_module)
@@ -168,9 +168,9 @@ def serialize_native(fun_jax: Callable,
xla_call_module_version=xla_call_module_version)
-def compute_dim_vars(module: mlir.ir.Module,
-args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
-"""Wraps the lowered module with a new "main" that computes the dim vars.
+def add_dim_arg_computation(module: mlir.ir.Module,
+args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
+"""Wraps the lowered module with a new "main" that computes the dim args.
JAX lowering in presence of shape polymorphism produces a `module` that
takes one or more dimension arguments, specified using 0-dimensional tensors
@@ -203,7 +203,7 @@ def compute_dim_vars(module: mlir.ir.Module,
Returns the wrapped module.
"""
-dim_args_builders = get_dim_arg_builders(args_avals)
+dim_vars = shape_poly.all_dim_vars(args_avals)
# Make a new module, do not mutate the "module" because it may be cached
context = mlir.make_ir_context()
@@ -216,7 +216,7 @@ def compute_dim_vars(module: mlir.ir.Module,
symbol_table.set_symbol_name(orig_main, orig_main_name)
orig_input_types = orig_main.type.inputs
-nr_array_args = len(orig_input_types) - len(dim_args_builders)
+nr_array_args = len(orig_input_types) - len(dim_vars)
assert nr_array_args >= 0
new_main_input_types = orig_input_types[- nr_array_args:]
@@ -237,11 +237,11 @@ def compute_dim_vars(module: mlir.ir.Module,
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
-orig_main_args = []
+orig_main_args: List[mlir.ir.Value] = []
+dim_args = compute_dim_args(args_avals, tuple(new_main_op.arguments),
+orig_input_types[:len(dim_vars)])
# The first arguments are the dimension variable
-for dim_arg_idx, dim_arg_builder in enumerate(dim_args_builders):
-orig_main_args.append(
-dim_arg_builder(new_main_op.arguments, orig_input_types[dim_arg_idx]))
+orig_main_args.extend(dim_args)
# Then the array arguments
orig_main_args.extend(new_main_op.arguments)
call = func_dialect.CallOp(orig_output_types,
@@ -252,60 +252,40 @@ def compute_dim_vars(module: mlir.ir.Module,
return new_module
-# A dimension argument builder computes a dimension argument given
-# the array arguments and the desired type of the dimension argument.
-DimArgBuilder = Callable[[Sequence[mlir.ir.Value], mlir.ir.Type], mlir.ir.Value]
-def get_dim_arg_builders(
-args_avals: Sequence[core.ShapedArray]) -> Sequence[DimArgBuilder]:
-"""For each dimension variable, return a builder.
+def compute_dim_args(
+args_avals: Sequence[core.ShapedArray],
+array_args: Sequence[mlir.ir.Value],
+dim_arg_types: Sequence[mlir.ir.Type]) -> Sequence[mlir.ir.Value]:
+"""Compute the values of the dimension arguments.
+Args:
+args_avals: the abstract values of the array arguments.
+array_args: the values of the array arguments.
+dim_arg_types: the desired types for the dimension arguments.
Returns:
-a list of DimArgBuilder, for each dimension variable appearing in `args_avals`
-in the sorted order of dimension variable name.
+the values of the dimension variables, in the sorted order of the
+dimension variables.
"""
-def get_dim_arg(array_arg_idx: int, dim_idx: int,
-array_args: Sequence[mlir.ir.Value],
-dim_arg_type: mlir.ir.Type) -> mlir.ir.Value:
-dim_arg = hlo.GetDimensionSizeOp(array_args[array_arg_idx], dim_idx)
-if dim_arg.result.type != dim_arg_type:
-return hlo.ConvertOp(dim_arg_type, dim_arg).result
+def get_dimension_size(args, arg_idx: int, dim_idx: int) -> mlir.DimExprValueMlir:
+dim_size = hlo.GetDimensionSizeOp(args[arg_idx], dim_idx).result
+dim_type = mlir.aval_to_ir_type(shape_poly.dim_as_value_abstract())
+if dim_size.type != dim_type:
+dim_size = hlo.ConvertOp(dim_type, dim_size).result
+return mlir.DimExprValueMlir(dim_size)
+all_dim_vars, dim_arg_builders = shape_poly.prepare_dim_var_env(
+args_avals, get_dimension_size)
+all_dim_args: Sequence[mlir.DimExprValueMlir] = dim_arg_builders(*array_args)
+res = []
+for dim_arg, dim_arg_type in zip(all_dim_args, dim_arg_types):
+dim_arg = dim_arg.value
+if dim_arg.type != dim_arg_type:
+res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
else:
-return dim_arg.result
-dim_args_builder_dict: Dict[str, DimArgBuilder] = {} # a builder for each dim var by name
-all_dim_vars: Set[str] = set()
-for arg_idx, aval in enumerate(args_avals):
-for axis_idx, d in enumerate(aval.shape):
-if not core.is_constant_dim(d):
-all_dim_vars = all_dim_vars.union(d.get_vars())
-d_var = d.to_var()
-# TODO(necula): compute dim vars from non-trivial expressions also
-if d_var is None: continue
-if not d_var in dim_args_builder_dict:
-dim_args_builder_dict[d_var] = partial(get_dim_arg, arg_idx, axis_idx)
-if all_dim_vars:
-dim_vars_with_builders_set = set(dim_args_builder_dict.keys())
-if dim_vars_with_builders_set != all_dim_vars:
-missing = all_dim_vars.difference(dim_vars_with_builders_set)
-args_list = [f" Arg[{arg_idx}]: {aval}"
-for arg_idx, aval in enumerate(args_avals)]
-raise ValueError(
-"The following dimension variables cannot be computed from the static "
-f"shapes of the array arguments: {missing}. The argument shapes are:\n" +
-"\n".join(args_list) +
-"\n"
-"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
-# In sorted order by name
-builders = [dim_args_builder_dict[d_var] for d_var in sorted(dim_args_builder_dict.keys())]
-else:
-builders = []
-return builders
+res.append(dim_arg)
+return tuple(res)
def check_module(mod: mlir.ir.Module, *,
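For intuition, add_dim_arg_computation turns an inner main that expects (dim_args..., array_args...) into a wrapper that takes only array_args and recomputes the dimension arguments from the array shapes (via hlo.GetDimensionSizeOp in the real code). A rough Python analogue, with the hypothetical names wrap_main, inner and solve standing in for the MLIR plumbing:

import numpy as np

def wrap_main(inner_main, dim_var_order, solve_dim_vars):
  # The wrapper recomputes the dimension arguments from the array arguments
  # and prepends them, in the sorted order of the dimension variables.
  def wrapped_main(*array_args):
    env = solve_dim_vars(*array_args)            # e.g. {"b1": 8, "b2": 6}
    dim_args = [env[v] for v in dim_var_order]
    return inner_main(*dim_args, *array_args)
  return wrapped_main

# Example for an argument declared as "2*b1, 4*b2".
inner = lambda b1, b2, x: (b1, b2, x.shape)
solve = lambda x: {"b1": x.shape[0] // 2, "b2": x.shape[1] // 4}
outer = wrap_main(inner, ["b1", "b2"], solve)
assert outer(np.zeros((16, 24))) == (8, 6, (16, 24))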

View File

@@ -56,10 +56,11 @@ from jax._src.typing import DimSize, Shape
TfVal = Any
# A dimension environment maps dimension variables to expressions that
-# compute the values of the dimension. An expression must support addition
+# denote the values of the dimension (DimExprValue). A DimExprValue must support addition
# and multiplication with other expressions or with int32/int64, e.g.,
# by overloading __add__, __radd__, __mul__, __rmul__, __divmod__, __rdivmod__.
-ShapeEnv = Dict[str, Any]
+DimExprValue = Any
+ShapeEnv = Dict[str, DimExprValue]
DType = Any
class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
@@ -324,7 +325,7 @@ class _DimExpr():
"""Symbolic expression in terms of dimension variables.
A dimension expression is an addition of products (_DimMon)
-f atoms (_DimAtom).
+of atoms (_DimAtom).
We overload integer operations, but we do that soundly, raising
:class:`InconclusiveDimensionOperation` when the result is not
@@ -647,7 +648,7 @@ class _DimExpr():
@staticmethod
def get_aval(dim: "_DimExpr"):
-return dim_as_value_abstract(dim)
+return dim_as_value_abstract()
def dimension_as_value(self):
"""Turns a dimension size into a Jax value that we can compute with."""
@@ -786,10 +787,10 @@ def dim_as_value_dtype():
def dim_constant(ct: int):
return np.array(ct, dtype=dim_as_value_dtype())
-def dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
+def dim_as_value_abstract() -> core.AbstractValue:
return core.ShapedArray((), dim_as_value_dtype(), weak_type=True)
-dim_as_value_p.def_abstract_eval(dim_as_value_abstract)
+dim_as_value_p.def_abstract_eval(lambda dim: dim_as_value_abstract())
def dim_as_value_impl(dim: DimSize):
raise NotImplementedError(
@@ -992,7 +993,7 @@ def _is_known_constant(v) -> Optional[int]:
# value of type shape_poly.dim_as_value_dtype().
dimension_size_p = core.Primitive("dimension_size")
def _dimension_size_abstract(aval: core.AbstractValue, **_) -> core.AbstractValue:
-return dim_as_value_abstract(aval)
+return dim_as_value_abstract()
dimension_size_p.def_abstract_eval(_dimension_size_abstract)
@@ -1004,9 +1005,13 @@ _JaxValue = Any
@dataclasses.dataclass
class DimEquation:
-# Represents poly == _expr
-poly: _DimExpr
-dim_expr: _JaxValue # Of type dim_as_value_dtype()
+# Represents args[arg_idx].shape[dim_idx] == dim_expr
+arg_idx: int
+dim_idx: int
+dim_expr: _DimExpr
+def __str__(self):
+return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"
def get_shape_evaluator(dim_vars: Sequence[str], shape: Sequence[DimSize]) ->\
@@ -1041,37 +1046,50 @@ def arg_aval(
aval_shape = _parse_spec(polymorphic_shape, arg_shape)
return core.ShapedArray(aval_shape, arg_jax_dtype)
-def prepare_dim_var_env(args_avals: Sequence[core.AbstractValue]) -> \
-Tuple[Sequence[str],
-Callable[..., Sequence[TfVal]]]:
-"""Get the dimension variables and the function to compute them.
-Returns a tuple of dimension variables that appear in `args_avals` along
-with a function that given the actual arguments of the top-level function
-returns a tuple of dimension variable values, in the same order as the
-dimension variables returned in the first component.
-The dimension variables are TfVal with type dim_as_value_dtype().
-"""
+def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Set[str]:
+dim_vars: Set[str] = set()
+for a in args_avals:
+for d in a.shape:
+if is_poly_dim(d):
+dim_vars = dim_vars.union(d.get_vars())
+return dim_vars
+def prepare_dim_var_env(
+args_avals: Sequence[core.AbstractValue],
+get_dimension_size: Callable[[Sequence[Any], int, int], DimExprValue]) -> \
+Tuple[Sequence[str],
+Callable[..., Sequence[DimExprValue]]]:
+"""Get the dimension variables and the function to compute them.
+Args:
+args_avals: the abstract values of the array arguments
+get_dimension_size: a function that given the array arguments, the argument
+index, and the dimension index, computes the value of args[arg_idx].shape[dim_idx].
+Returns: a tuple of dimension variables that appear in `args_avals` along
+with a function that given the actual arguments of the top-level function
+returns a tuple of dimension variable values, in the same order as the
+dimension variables returned in the first component.
+"""
+# TODO(necula): clean up the notion of shape environments, and solving for dimension
+# variables.
+dim_vars: Set[str] = all_dim_vars(args_avals)
+dim_vars_sorted = sorted(tuple(dim_vars))
def get_dim_var_values(*args: Any) -> Sequence[Any]:
dim_equations: List[DimEquation] = []
-for a in args:
-for i, d in enumerate(a.shape):
+for arg_idx, a in enumerate(args_avals):
+for dim_idx, d in enumerate(a.shape):
if is_poly_dim(d):
-dim_equations.append(DimEquation(
-poly=d, dim_expr=dimension_size_p.bind(a, dimension=i)))
+dim_equations.append(
+DimEquation(arg_idx=arg_idx, dim_idx=dim_idx, dim_expr=d))
+dim_env = _solve_dim_equations(dim_equations,
+functools.partial(get_dimension_size, args))
+return tuple(dim_env[dv] for dv in dim_vars_sorted)
+return dim_vars_sorted, get_dim_var_values
-dim_env = _solve_dim_equations(dim_equations)
-assert all(dim_env[dv].dtype == dim_as_value_dtype() for dv in dim_vars)
-return tuple(dim_env[dv] for dv in dim_vars)
-return tuple(dim_vars), get_dim_var_values
-def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
+def _solve_dim_equations(eqns: List[DimEquation],
+get_dimension_size: Callable[[int, int], DimExprValue]) -> ShapeEnv:
# Returns a shape environment if it can solve all dimension variables.
# Raises an exception if it cannot.
shapeenv: ShapeEnv = {}
@@ -1084,19 +1102,19 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
return ""
def process_one_eqn(eqn: DimEquation) -> bool:
-# Try to rewrite the equation as "var * factor_var = dim_expr" (a linear
-# uni-variate equation. Return False if this rewrite fails.
+# Try to rewrite the equation as "var * factor_var = dim_value" (a linear
+# uni-variate equation). Return False if this rewrite fails.
# Otherwise, add the variable to shapeenv and return True.
-# The invariant is: var * factor_var + rest_eqn_poly = dim_expr
var, factor_var = None, None
-dim_expr = eqn.dim_expr
+dim_value = get_dimension_size(eqn.arg_idx, eqn.dim_idx)
+# The invariant is: var * factor_var + rest_eqn_dim_expr = dim_value
-for mon, factor in eqn.poly.monomials():
+for mon, factor in eqn.dim_expr.monomials():
# Perhaps we can already evaluate this monomial (all vars solved)
try:
mon_value = mon.evaluate(shapeenv)
-dim_expr = dim_expr - _evaluate_multiply(mon_value, dim_constant(factor))
+dim_value = dim_value + -1 * _evaluate_multiply(mon_value, dim_constant(factor))
continue
except KeyError:
# There are some indeterminate variables. We handle only the case of
@@ -1110,10 +1128,9 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
if var is not None:
if factor_var == 1:
-var_value, var_remainder = dim_expr, dim_constant(0)
+var_value, var_remainder = dim_value, dim_constant(0)
else:
-var_value = lax.div(dim_expr, dim_constant(factor_var)) # type: ignore
-var_remainder = lax.rem(dim_expr, dim_constant(factor_var)) # type: ignore
+var_value, var_remainder = divmod(dim_value, dim_constant(factor_var)) # type: ignore
# Check that the division is even. Works only in eager mode.
var_remainder_int = _is_known_constant(var_remainder)
@@ -1121,25 +1138,25 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
# TODO(necula): check even in graph mode, by embedding the checks in
# the graph.
msg = (f"Dimension variable {var} must have integer value >= 1. " # type: ignore
-f"Found value {int(_is_known_constant(dim_expr)) / factor_var} when solving " # type: ignore
-f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+f"Found value {int(_is_known_constant(dim_value)) / factor_var} when solving " # type: ignore
+f"{eqn}.{_shapeenv_to_str()}")
raise ValueError(msg)
var_value_int = _is_known_constant(var_value)
if var_value_int is not None and var_value_int <= 0:
msg = (f"{var_value_int} Dimension variable {var} must have integer value >= 1. "
f"Found value {int(var_value_int)} when solving "
-f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+f"{eqn}.{_shapeenv_to_str()}")
raise ValueError(msg)
shapeenv[var] = var_value
return True
else:
# All variables are resolved for this equation
-dim_expr_int = _is_known_constant(dim_expr)
-if dim_expr_int is not None and dim_expr_int != 0:
+dim_value_int = _is_known_constant(dim_value)
+if dim_value_int is not None and dim_value_int != 0:
err_msg = (
"Found inconsistency when solving "
-f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+f"{eqn}.{_shapeenv_to_str()}")
raise ValueError(err_msg)
return True
@@ -1155,10 +1172,10 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
unsolved_vars: Set[str] = set()
unsolved_polys: List[_DimExpr] = []
for eqn in eqns:
-unsolved_vars = unsolved_vars.union(eqn.poly.get_vars())
-unsolved_polys.append(eqn.poly)
+unsolved_vars = unsolved_vars.union(eqn.dim_expr.get_vars())
+unsolved_polys.append(eqn.dim_expr)
unsolved_vars = unsolved_vars.difference(shapeenv.keys())
-eqns_str = "\n ".join([str(eqn.poly) for eqn in eqns])
+eqns_str = "\n ".join([str(eqn.dim_expr) for eqn in eqns])
err_msg = (
f"Cannot solve for values of dimension variables {unsolved_vars} from "
f"the remaining dimension polynomials\n {eqns_str}.{_shapeenv_to_str()} "

View File

@@ -739,7 +739,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
arg_dtypes = (_f32,) * len(arg_shapes)
def f_tf(*args_tf):
avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
-dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals)
+def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
+return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
+dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals, get_dimension_size)
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(get_dim_values_jax,
args_tf, avals, "")
if expected_avals is not None:
@@ -1007,16 +1009,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
arg_descriptors=[RandArg((3, 4), _f32)],
poly_axes=[(0, 1)])
-def test_non_trivial_polynomials_spec(self):
-if config.jax_dynamic_shapes:
-raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
+@parameterized.named_parameters(
+dict(testcase_name=f"_{str(polymorphic_shapes)}",
+polymorphic_shapes=polymorphic_shapes)
+# The polymorphic_shapes should have three comma-separated DimExpr matching
+# 16, 24, 32
+for polymorphic_shapes in [
+"b1+6,b1+14,b2", # b1=10, b2=32
+"2*b1,4*b2,b1+b2+18", # b1=8,b2=6
+"b1+2*b2,4*b2,b1*b1+16", # b1=4,b2=6
+])
+def test_non_trivial_polynomials_spec(self,
+polymorphic_shapes="2*b1,4*b2,b1+b2+18"):
# We can handle non-trivial polynomials in the input shape,
-# as long as all variables also occur in trivial polynoamials
+# as long as all variables also occur in trivial expressions
check_shape_poly(self,
-lambda x, y: x + y.reshape((-1,)),
-arg_descriptors=[RandArg((9,), _f32), RandArg((3, 3), _f32)],
-input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
-polymorphic_shapes=["b * b", "b, b"])
+lambda x: 2 * x.shape[0] + 3 * x.shape[1] + 4 * x.shape[2],
+arg_descriptors=[RandArg((16, 24, 32), _f32)],
+input_signature=[tf.TensorSpec([None, None, None])],
+polymorphic_shapes=polymorphic_shapes)
def test_unused_args(self):
# Tests with functions that do not use their inputs.
@@ -1166,9 +1177,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))
def test_grad_not_var_output(self):
# Output of the function has poly shapes, non-variable
-if config.jax2tf_default_native_serialization:
-raise unittest.SkipTest("Not supported with native serialization")
def f_jax(x): # :[b, 3]
return jnp.reshape(x, (-1,)) # : [3b]
x = np.arange(12, dtype=np.float32).reshape((4, 3))
@@ -1182,7 +1190,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
grad_tf = tape.gradient(res_tf, xv)
self.assertAllClose(np.ones(x.shape, dtype=np.float32), grad_tf.numpy())
def test_cond(self):
# Test the primitive under conditional
def f(x, y):
@@ -1280,9 +1287,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(f_jax(y), restored_f(y))
def test_saved_model_int_function(self):
-# TODO(https://github.com/google/jax/issues/14437)
-if config.jax2tf_default_native_serialization:
-raise unittest.SkipTest("Gradient function does not use the dimension variables")
def f_jax(x): # x:s32[b, 3, 4]
return jnp.reshape(x, (-1,)) # : s32[b * 12]
@@ -1301,10 +1305,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(f_jax(x), res_jax_rt)
def test_saved_model_constant_gradient(self):
-# TODO(https://github.com/google/jax/issues/14437)
-if config.jax2tf_default_native_serialization:
-raise unittest.SkipTest("Gradient function does not use the dimension variables")
def f_jax(x): # A function whose gradient is a constant
return x
@@ -1369,10 +1369,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
-if not config.jax2tf_default_native_serialization:
-# Does not support 2*b constraints
-jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
-polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
+jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
+polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
with self.assertRaisesRegex(
core.InconclusiveDimensionOperation,
@@ -1398,7 +1396,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False)(x0)
# TODO(https://github.com/google/jax/issues/14437)
# In native serialization, or if we trace to a TF graph, we miss this
res1_tf = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=True)(x0)
@@ -1427,7 +1424,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)(x45)
# TODO(https://github.com/google/jax/issues/14437)
# In native serialization, or if we trace to a TF graph, we miss this
res2_tf = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=True)(x45)
@@ -1443,13 +1439,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
with self.assertRaisesRegex(
ValueError,
"Cannot solve for values of dimension variables"):
-jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"],
-native_serialization=False)(x)
-with self.assertRaisesRegex(
-ValueError,
-"dimension variables cannot be computed from the static shapes of the array arguments"):
-jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"],
-native_serialization=True)(x)
+jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"])(x)
def test_dynamic_shapes(self):
# Test dim_as_value with dynamic shapes.
@@ -1470,11 +1461,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
self.assertAllClose((9., 1.8), (res_primal, res_tangent))
-# TODO(https://github.com/google/jax/issues/14437)
-if not config.jax2tf_default_native_serialization:
-self.assertAllClose(
-np.array([3., 3., 3.]),
-jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
+self.assertAllClose(
+np.array([3., 3., 3.]),
+jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
xv = np.arange(24.).reshape((2, 3, 4))
res_vmap = jax.vmap(f, in_axes=1)(xv)
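The reshape test above, now exercised uniformly with and without native serialization, boils down to usage like the following (a minimal sketch based on that test; the dtype and printout are added for illustration):

import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

# "2*b" is solved from the concrete input shape: for an input of shape
# (4, 5, 7) we get b == 2, so reshaping to (2, -1) is valid for every b.
f_tf = jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
                      polymorphic_shapes=["(2*b, ...)"])
print(f_tf(np.ones((4, 5, 7), dtype=np.float32)).shape)  # (2, 70)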

View File

@@ -16,7 +16,6 @@ from jax._src.interpreters.mlir import (
AxisContext as AxisContext,
ConstantHandler as ConstantHandler,
DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
-DimExprEvaluator as DimExprEvaluator,
LoweringResult as LoweringResult,
LoweringRule as LoweringRule,
LoweringRuleContext as LoweringRuleContext,