Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #15287 from gnecula:tf_dim_vars
PiperOrigin-RevId: 520633830
Commit: 1fd6e01289
@@ -573,32 +573,33 @@ def sharded_aval(aval: core.ShapedArray,
   return aval.update(tuple(sharded_shape))


-class DimExprEvaluator:
+class DimExprValueMlir:
+  # TODO(necula): remove this, use regular JAX lowering
   # A wrapper for an ir.Value that overloads + and * to be used for evaluating
-  # symbolic dimensions.
+  # symbolic dimensions resulting in ir.Value. See shape_poly.DimExprValue.
   __array_priority__ = 1000  # Same as tracer, for __radd__ and others on ndarray
   def __init__(self, value: ir.Value):
     self.value = value

-  def __add__(self, other: Union[np.int32, np.int64, DimExprEvaluator]):
-    if not isinstance(other, DimExprEvaluator):
-      other = DimExprEvaluator(ir_constant(other))
-    return DimExprEvaluator(hlo.AddOp(self.value, other.value).result)
+  def __add__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
+    if not isinstance(other, DimExprValueMlir):
+      other = DimExprValueMlir(ir_constant(other))
+    return DimExprValueMlir(hlo.AddOp(self.value, other.value).result)

   def __radd__(self, other: Union[np.int32, np.int64]):
-    return DimExprEvaluator(ir_constant(other)).__add__(self)
+    return DimExprValueMlir(ir_constant(other)).__add__(self)

-  def __mul__(self, other: Union[np.int32, np.int64, DimExprEvaluator]):
-    if not isinstance(other, DimExprEvaluator):
-      other = DimExprEvaluator(ir_constant(other))
-    return DimExprEvaluator(hlo.MulOp(self.value, other.value).result)
+  def __mul__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
+    if not isinstance(other, DimExprValueMlir):
+      other = DimExprValueMlir(ir_constant(other))
+    return DimExprValueMlir(hlo.MulOp(self.value, other.value).result)

   def __rmul__(self, other: Union[np.int32, np.int64]):
-    return DimExprEvaluator(ir_constant(other)).__mul__(self)
+    return DimExprValueMlir(ir_constant(other)).__mul__(self)

-  def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprEvaluator]):
-    if not isinstance(divisor, DimExprEvaluator):
-      divisor = DimExprEvaluator(ir_constant(divisor))
+  def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprValueMlir]):
+    if not isinstance(divisor, DimExprValueMlir):
+      divisor = DimExprValueMlir(ir_constant(divisor))
     # Quotient
     raw_quotient = hlo.DivOp(self.value, divisor.value)
     raw_remainder = hlo.RemOp(self.value, divisor.value)
@@ -613,11 +614,11 @@ class DimExprEvaluator:
                             raw_quotient)
     # Remainder
     remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
-    return (DimExprEvaluator(quotient.result),
-            DimExprEvaluator(remainder.result))
+    return (DimExprValueMlir(quotient.result),
+            DimExprValueMlir(remainder.result))

   def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
-    return DimExprEvaluator(ir_constant(dividend)).__divmod__(self)
+    return DimExprValueMlir(ir_constant(dividend)).__divmod__(self)

 def eval_dynamic_shape(ctx: LoweringRuleContext,
                        shape: core.Shape) -> Tuple[Union[int, Value], ...]:
@@ -625,7 +626,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
   if config.jax_dynamic_shapes:
     return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore
   else:
-    dim_var_env = {dv_name : DimExprEvaluator(dv_val[0])
+    dim_var_env = {dv_name: DimExprValueMlir(dv_val[0])
                    for dv_name, dv_val in zip(ctx.module_context.dim_vars, ctx.dim_var_values)}
     def eval_dim(d: core.DimSize) -> Union[int, ir.Value]:
       try:
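
The renamed DimExprValueMlir class exists so that shape_poly dimension expressions can be evaluated directly into MLIR values: evaluation only needs +, *, and divmod on whatever value type is substituted for the dimension variables. The following is a minimal pure-Python sketch of that same DimExprValue contract, using plain ints in place of ir.Value; the class name DimExprValueInt is illustrative only and is not part of this change.

from typing import Union

class DimExprValueInt:
  """Wraps an int and supports +, * and divmod, as a DimExprValue must."""
  def __init__(self, value: int):
    self.value = value

  def __add__(self, other: Union[int, "DimExprValueInt"]) -> "DimExprValueInt":
    if not isinstance(other, DimExprValueInt):
      other = DimExprValueInt(other)
    return DimExprValueInt(self.value + other.value)
  __radd__ = __add__

  def __mul__(self, other: Union[int, "DimExprValueInt"]) -> "DimExprValueInt":
    if not isinstance(other, DimExprValueInt):
      other = DimExprValueInt(other)
    return DimExprValueInt(self.value * other.value)
  __rmul__ = __mul__

  def __divmod__(self, divisor: Union[int, "DimExprValueInt"]):
    if not isinstance(divisor, DimExprValueInt):
      divisor = DimExprValueInt(divisor)
    q, r = divmod(self.value, divisor.value)
    return DimExprValueInt(q), DimExprValueInt(r)

# Evaluating the symbolic dimension 2*b + 1 with b == 5 then proceeds purely
# through the overloaded operators, exactly as the MLIR wrapper does with hlo ops.
assert (2 * DimExprValueInt(5) + 1).value == 11
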
@@ -399,8 +399,11 @@ def convert(fun_jax: Callable,
                                     native_serialization_strict_checks)
       return outs_tf, out_avals
     else:
+      def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
+        return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
       dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(
-          args_avals_flat)
+          args_avals_flat, get_dimension_size)

       dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
                                          args_avals_flat, name_stack)
       shape_env = zip(dim_vars, dim_values)  # type: ignore
@@ -984,6 +987,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,

 def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
   # Returns a tuple of shape_poly.dim_as_value_dtype
+  # Used only for non-native lowering
   assert all(map(lambda x: x is not None, shape)), (
       f"Argument shape should be a valid JAX shape but got {shape}")
   if dtype is not None:
@@ -993,7 +997,7 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV

   dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
   eval_shape_jax = shape_poly.get_shape_evaluator(dim_vars, shape)
-  dim_aval = shape_poly.dim_as_value_abstract(1)
+  dim_aval = shape_poly.dim_as_value_abstract()
   shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
                                           dim_values, [dim_aval] * len(dim_values), "")  # type: ignore
   # Keep only the non-constant dimensions
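
The jax2tf caller above supplies a get_dimension_size callback, so prepare_dim_var_env itself stays agnostic about what kind of value a dimension becomes (a traced JAX/TF value here, an ir.Value during native serialization). A hedged sketch of the same callback contract instantiated over plain NumPy arrays; the name np_get_dimension_size is hypothetical and the commented usage is a sketch, not verified against the exact API.

import numpy as np

def np_get_dimension_size(args, arg_idx: int, dim_idx: int) -> np.int64:
  # Returns the value of args[arg_idx].shape[dim_idx] as a DimExprValue;
  # a NumPy scalar already supports +, * and divmod, which is all that is needed.
  return np.int64(args[arg_idx].shape[dim_idx])

# Sketch of how it would be threaded through (assumed call pattern):
#   dim_vars, get_dim_values = shape_poly.prepare_dim_var_env(
#       args_avals, np_get_dimension_size)
#   dim_values = get_dim_values(*actual_args)
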
@@ -35,7 +35,7 @@ from jax._src.lib.mlir.dialects import stablehlo
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect

 from jax.experimental.jax2tf import shape_poly

 map = util.safe_map
 zip = util.safe_zip
@@ -127,7 +127,7 @@ def serialize_native(fun_jax: Callable,
   if not all(core.is_constant_shape(a.shape) for a in args_avals):
     # All arguments are kept if we have dimension variables.
     assert len(module_kept_var_idx) == len(args_avals)
-    mlir_module = compute_dim_vars(mlir_module, args_avals)
+    mlir_module = add_dim_arg_computation(mlir_module, args_avals)

   xla_call_module_version = 4
   mlir_str = mlir.module_to_bytecode(mlir_module)
@@ -168,9 +168,9 @@ def serialize_native(fun_jax: Callable,
       xla_call_module_version=xla_call_module_version)


-def compute_dim_vars(module: mlir.ir.Module,
-                     args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
-  """Wraps the lowered module with a new "main" that computes the dim vars.
+def add_dim_arg_computation(module: mlir.ir.Module,
+                            args_avals: Sequence[core.ShapedArray]) -> mlir.ir.Module:
+  """Wraps the lowered module with a new "main" that computes the dim args.

   JAX lowering in presence of shape polymorphism produces a `module` that
   takes one or more dimension arguments, specified using 0-dimensional tensors
@@ -203,7 +203,7 @@ def compute_dim_vars(module: mlir.ir.Module,

   Returns the wrapped module.
   """
-  dim_args_builders = get_dim_arg_builders(args_avals)
+  dim_vars = shape_poly.all_dim_vars(args_avals)

   # Make a new module, do not mutate the "module" because it may be cached
   context = mlir.make_ir_context()
@@ -216,7 +216,7 @@ def compute_dim_vars(module: mlir.ir.Module,
   symbol_table.set_symbol_name(orig_main, orig_main_name)

   orig_input_types = orig_main.type.inputs
-  nr_array_args = len(orig_input_types) - len(dim_args_builders)
+  nr_array_args = len(orig_input_types) - len(dim_vars)
   assert nr_array_args >= 0

   new_main_input_types = orig_input_types[- nr_array_args:]
@@ -237,11 +237,11 @@ def compute_dim_vars(module: mlir.ir.Module,
   symbol_table.insert(new_main_op)
   entry_block = new_main_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
-    orig_main_args = []
+    orig_main_args: List[mlir.ir.Value] = []
+    dim_args = compute_dim_args(args_avals, tuple(new_main_op.arguments),
+                                orig_input_types[:len(dim_vars)])
     # The first arguments are the dimension variable
-    for dim_arg_idx, dim_arg_builder in enumerate(dim_args_builders):
-      orig_main_args.append(
-          dim_arg_builder(new_main_op.arguments, orig_input_types[dim_arg_idx]))
+    orig_main_args.extend(dim_args)
     # Then the array arguments
     orig_main_args.extend(new_main_op.arguments)
     call = func_dialect.CallOp(orig_output_types,
@@ -252,60 +252,40 @@ def compute_dim_vars(module: mlir.ir.Module,
   return new_module


-# A dimension argument builder computes a dimension argument given
-# the array arguments and the desired type of the dimension argument.
-DimArgBuilder = Callable[[Sequence[mlir.ir.Value], mlir.ir.Type], mlir.ir.Value]
-
-def get_dim_arg_builders(
-    args_avals: Sequence[core.ShapedArray]) -> Sequence[DimArgBuilder]:
-  """For each dimension variable, return a builder.
-
-  Returns:
-    a list of DimArgBuilder, for each dimension variable appearing in `args_avals`
-    in the sorted order of dimension variable name.
-  """
-  def get_dim_arg(array_arg_idx: int, dim_idx: int,
-                  array_args: Sequence[mlir.ir.Value],
-                  dim_arg_type: mlir.ir.Type) -> mlir.ir.Value:
-    dim_arg = hlo.GetDimensionSizeOp(array_args[array_arg_idx], dim_idx)
-    if dim_arg.result.type != dim_arg_type:
-      return hlo.ConvertOp(dim_arg_type, dim_arg).result
-    else:
-      return dim_arg.result
-
-  dim_args_builder_dict: Dict[str, DimArgBuilder] = {}  # a builder for each dim var by name
-  all_dim_vars: Set[str] = set()
-  for arg_idx, aval in enumerate(args_avals):
-    for axis_idx, d in enumerate(aval.shape):
-      if not core.is_constant_dim(d):
-        all_dim_vars = all_dim_vars.union(d.get_vars())
-        d_var = d.to_var()
-        # TODO(necula): compute dim vars from non-trivial expressions also
-        if d_var is None: continue
-        if not d_var in dim_args_builder_dict:
-          dim_args_builder_dict[d_var] = partial(get_dim_arg, arg_idx, axis_idx)
-
-  if all_dim_vars:
-    dim_vars_with_builders_set = set(dim_args_builder_dict.keys())
-    if dim_vars_with_builders_set != all_dim_vars:
-      missing = all_dim_vars.difference(dim_vars_with_builders_set)
-      args_list = [f"  Arg[{arg_idx}]: {aval}"
-                   for arg_idx, aval in enumerate(args_avals)]
-      raise ValueError(
-          "The following dimension variables cannot be computed from the static "
-          f"shapes of the array arguments: {missing}. The argument shapes are:\n" +
-          "\n".join(args_list) +
-          "\n"
-          "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
-
-    # In sorted order by name
-    builders = [dim_args_builder_dict[d_var] for d_var in sorted(dim_args_builder_dict.keys())]
-  else:
-    builders = []
-  return builders
+def compute_dim_args(
+    args_avals: Sequence[core.ShapedArray],
+    array_args: Sequence[mlir.ir.Value],
+    dim_arg_types: Sequence[mlir.ir.Type]) -> Sequence[mlir.ir.Value]:
+  """Compute the values of the dimension arguments.
+
+  Args:
+    args_avals: the abstract values of the array arguments.
+    array_args: the values of the array arguments.
+    dim_arg_types: the desired types for the dimension arguments.
+
+  Returns:
+    the values of the dimension variables, in the sorted order of the
+    dimension variables.
+  """
+  def get_dimension_size(args, arg_idx: int, dim_idx: int) -> mlir.DimExprValueMlir:
+    dim_size = hlo.GetDimensionSizeOp(args[arg_idx], dim_idx).result
+    dim_type = mlir.aval_to_ir_type(shape_poly.dim_as_value_abstract())
+    if dim_size.type != dim_type:
+      dim_size = hlo.ConvertOp(dim_type, dim_size).result
+    return mlir.DimExprValueMlir(dim_size)
+
+  all_dim_vars, dim_arg_builders = shape_poly.prepare_dim_var_env(
+      args_avals, get_dimension_size)
+  all_dim_args: Sequence[mlir.DimExprValueMlir] = dim_arg_builders(*array_args)
+
+  res = []
+  for dim_arg, dim_arg_type in zip(all_dim_args, dim_arg_types):
+    dim_arg = dim_arg.value
+    if dim_arg.type != dim_arg_type:
+      res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
+    else:
+      res.append(dim_arg)
  return tuple(res)


 def check_module(mod: mlir.ir.Module, *,
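
The wrapping performed by add_dim_arg_computation is: the new "main" takes only the array arguments, computes each dimension argument from them (via GetDimensionSize in the MLIR case), and then calls the original main with the dimension arguments prepended. A pure-Python analogue of that shape of wrapping, with hypothetical names (wrap_main, orig_main), only to illustrate the calling convention:

from typing import Callable, Sequence
import numpy as np

def wrap_main(orig_main: Callable,
              dim_arg_builders: Sequence[Callable[..., int]]) -> Callable:
  def new_main(*array_args):
    # First compute the dimension arguments from the array arguments ...
    dim_args = [build(*array_args) for build in dim_arg_builders]
    # ... then call the original main with them prepended.
    return orig_main(*dim_args, *array_args)
  return new_main

# Example: orig_main expects (b, x) where b == x.shape[0].
orig_main = lambda b, x: np.full((b,), x.sum())
new_main = wrap_main(orig_main, [lambda x: x.shape[0]])
assert new_main(np.ones((3, 2))).shape == (3,)
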
@@ -56,10 +56,11 @@ from jax._src.typing import DimSize, Shape

 TfVal = Any
 # A dimension environment maps dimension variables to expressions that
-# compute the values of the dimension. An expression must support addition
+# denote the values of the dimension (DimExprValue). A DimExprValue must support addition
 # and multiplication with other expressions or with int32/int64, e.g.,
 # by overloading __add__, __radd__, __mul__, __rmul__, __divmod__, __rdivmod__.
-ShapeEnv = Dict[str, Any]
+DimExprValue = Any
+ShapeEnv = Dict[str, DimExprValue]
 DType = Any

 class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
@@ -324,7 +325,7 @@ class _DimExpr():
   """Symbolic expression in terms of dimension variables.

   A dimension expression is an addition of products (_DimMon)
-  f atoms (_DimAtom).
+  of atoms (_DimAtom).

   We overload integer operations, but we do that soundly, raising
   :class:`InconclusiveDimensionOperation` when the result is not
@@ -647,7 +648,7 @@ class _DimExpr():

   @staticmethod
   def get_aval(dim: "_DimExpr"):
-    return dim_as_value_abstract(dim)
+    return dim_as_value_abstract()

   def dimension_as_value(self):
     """Turns a dimension size into a Jax value that we can compute with."""
@@ -786,10 +787,10 @@ def dim_as_value_dtype():
 def dim_constant(ct: int):
   return np.array(ct, dtype=dim_as_value_dtype())

-def dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
+def dim_as_value_abstract() -> core.AbstractValue:
   return core.ShapedArray((), dim_as_value_dtype(), weak_type=True)

-dim_as_value_p.def_abstract_eval(dim_as_value_abstract)
+dim_as_value_p.def_abstract_eval(lambda dim: dim_as_value_abstract())

 def dim_as_value_impl(dim: DimSize):
   raise NotImplementedError(
@@ -992,7 +993,7 @@ def _is_known_constant(v) -> Optional[int]:
 # value of type shape_poly.dim_as_value_dtype().
 dimension_size_p = core.Primitive("dimension_size")
 def _dimension_size_abstract(aval: core.AbstractValue, **_) -> core.AbstractValue:
-  return dim_as_value_abstract(aval)
+  return dim_as_value_abstract()

 dimension_size_p.def_abstract_eval(_dimension_size_abstract)

@@ -1004,9 +1005,13 @@ _JaxValue = Any

 @dataclasses.dataclass
 class DimEquation:
-  # Represents poly == _expr
-  poly: _DimExpr
-  dim_expr: _JaxValue  # Of type dim_as_value_dtype()
+  # Represents args[arg_idx].shape[dim_idx] == dim_expr
+  arg_idx: int
+  dim_idx: int
+  dim_expr: _DimExpr
+
+  def __str__(self):
+    return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"


 def get_shape_evaluator(dim_vars: Sequence[str], shape: Sequence[DimSize]) ->\
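
After this change a DimEquation records which actual argument dimension a symbolic expression must equal, instead of carrying a pre-computed value. A stripped-down, illustrative analogue of that record (the class below is a sketch; in shape_poly.py dim_expr is a _DimExpr, not a string):

import dataclasses

@dataclasses.dataclass
class DimEquationSketch:
  # Represents args[arg_idx].shape[dim_idx] == dim_expr
  arg_idx: int
  dim_idx: int
  dim_expr: str

  def __str__(self):
    return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"

# For an argument declared with the example polymorphic shape "2*b, 8", only the
# non-constant dimension yields an equation:
assert str(DimEquationSketch(0, 0, "2*b")) == "2*b == args[0].shape[0]"
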
@@ -1041,37 +1046,50 @@ def arg_aval(
   aval_shape = _parse_spec(polymorphic_shape, arg_shape)
   return core.ShapedArray(aval_shape, arg_jax_dtype)

-def prepare_dim_var_env(args_avals: Sequence[core.AbstractValue]) -> \
-    Tuple[Sequence[str],
-          Callable[..., Sequence[TfVal]]]:
-  """Get the dimension variables and the function to compute them.
-
-  Returns a tuple of dimension variables that appear in `args_avals` along
-  with a function that given the actual arguments of the top-level function
-  returns a tuple of dimension variable values, in the same order as the
-  dimension variables returned in the first component.
-  The dimension variables are TfVal with type dim_as_value_dtype().
-  """
+def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Set[str]:
   dim_vars: Set[str] = set()
   for a in args_avals:
     for d in a.shape:
       if is_poly_dim(d):
         dim_vars = dim_vars.union(d.get_vars())
+  return dim_vars
+
+def prepare_dim_var_env(
+    args_avals: Sequence[core.AbstractValue],
+    get_dimension_size: Callable[[Sequence[Any], int, int], DimExprValue]) -> \
+    Tuple[Sequence[str],
+          Callable[..., Sequence[DimExprValue]]]:
+  """Get the dimension variables and the function to compute them.
+
+  Args:
+    args_aval: the abstract values of the array arguments
+    get_dimension_size: a function that given the array arguments, the argument
+      index, and the dimension index, computes the value of args[arg_idx].shape[dim_idx].
+
+  Returns: a tuple of dimension variables that appear in `args_avals` along
+  with a function that given the actual arguments of the top-level function
+  returns a tuple of dimension variable values, in the same order as the
+  dimension variables returned in the first component.
+  """
+  # TODO(necula): clean up the notion of shape environments, and solving for dimension
+  # variables.
+  dim_vars: Set[str] = all_dim_vars(args_avals)
+  dim_vars_sorted = sorted(tuple(dim_vars))
   def get_dim_var_values(*args: Any) -> Sequence[Any]:
     dim_equations: List[DimEquation] = []
-    for a in args:
-      for i, d in enumerate(a.shape):
+    for arg_idx, a in enumerate(args_avals):
+      for dim_idx, d in enumerate(a.shape):
         if is_poly_dim(d):
-          dim_equations.append(DimEquation(
-              poly=d, dim_expr=dimension_size_p.bind(a, dimension=i)))
+          dim_equations.append(
+              DimEquation(arg_idx=arg_idx, dim_idx=dim_idx, dim_expr=d))

-    dim_env = _solve_dim_equations(dim_equations)
-    assert all(dim_env[dv].dtype == dim_as_value_dtype() for dv in dim_vars)
-    return tuple(dim_env[dv] for dv in dim_vars)
-  return tuple(dim_vars), get_dim_var_values
+    dim_env = _solve_dim_equations(dim_equations,
+                                   functools.partial(get_dimension_size, args))
+    return tuple(dim_env[dv] for dv in dim_vars_sorted)
+  return dim_vars_sorted, get_dim_var_values

-def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
+def _solve_dim_equations(eqns: List[DimEquation],
+                         get_dimension_size: Callable[[int, int], DimExprValue]) -> ShapeEnv:
   # Returns a shape environment if it can solve all dimension variables.
   # Raises an exception if it cannot.
   shapeenv: ShapeEnv = {}
@@ -1084,19 +1102,19 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
     return ""

   def process_one_eqn(eqn: DimEquation) -> bool:
-    # Try to rewrite the equation as "var * factor_var = dim_expr" (a linear
-    # uni-variate equation. Return False if this rewrite fails.
+    # Try to rewrite the equation as "var * factor_var = dim_value" (a linear
+    # uni-variate equation). Return False if this rewrite fails.
     # Otherwise, add the variable to shapeenv and return True.

-    # The invariant is: var * factor_var + rest_eqn_poly = dim_expr
     var, factor_var = None, None
-    dim_expr = eqn.dim_expr
+    dim_value = get_dimension_size(eqn.arg_idx, eqn.dim_idx)
+    # The invariant is: var * factor_var + rest_eqn_dim_expr = dim_value

-    for mon, factor in eqn.poly.monomials():
+    for mon, factor in eqn.dim_expr.monomials():
       # Perhaps we can already evaluate this monomial (all vars solved)
       try:
         mon_value = mon.evaluate(shapeenv)
-        dim_expr = dim_expr - _evaluate_multiply(mon_value, dim_constant(factor))
+        dim_value = dim_value + -1 * _evaluate_multiply(mon_value, dim_constant(factor))
         continue
       except KeyError:
         # There are some indeterminate variables. We handle only the case of
@@ -1110,10 +1128,9 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:

     if var is not None:
       if factor_var == 1:
-        var_value, var_remainder = dim_expr, dim_constant(0)
+        var_value, var_remainder = dim_value, dim_constant(0)
       else:
-        var_value = lax.div(dim_expr, dim_constant(factor_var))  # type: ignore
-        var_remainder = lax.rem(dim_expr, dim_constant(factor_var))  # type: ignore
+        var_value, var_remainder = divmod(dim_value, dim_constant(factor_var))  # type: ignore

       # Check that the division is even. Works only in eager mode.
       var_remainder_int = _is_known_constant(var_remainder)
@@ -1121,25 +1138,25 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
         # TODO(necula): check even in graph mode, by embedding the checks in
         # the graph.
         msg = (f"Dimension variable {var} must have integer value >= 1. "  # type: ignore
-               f"Found value {int(_is_known_constant(dim_expr)) / factor_var} when solving "  # type: ignore
-               f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+               f"Found value {int(_is_known_constant(dim_value)) / factor_var} when solving "  # type: ignore
+               f"{eqn}.{_shapeenv_to_str()}")
         raise ValueError(msg)
       var_value_int = _is_known_constant(var_value)
       if var_value_int is not None and var_value_int <= 0:
         msg = (f"{var_value_int} Dimension variable {var} must have integer value >= 1. "
               f"Found value {int(var_value_int)} when solving "
-               f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+               f"{eqn}.{_shapeenv_to_str()}")
         raise ValueError(msg)

       shapeenv[var] = var_value
       return True
     else:
       # All variables are resolved for this equation
-      dim_expr_int = _is_known_constant(dim_expr)
-      if dim_expr_int is not None and dim_expr_int != 0:
+      dim_value_int = _is_known_constant(dim_value)
+      if dim_value_int is not None and dim_value_int != 0:
         err_msg = (
             "Found inconsistency when solving "
-            f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+            f"{eqn}.{_shapeenv_to_str()}")
         raise ValueError(err_msg)
       return True

@@ -1155,10 +1172,10 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
   unsolved_vars: Set[str] = set()
   unsolved_polys: List[_DimExpr] = []
   for eqn in eqns:
-    unsolved_vars = unsolved_vars.union(eqn.poly.get_vars())
-    unsolved_polys.append(eqn.poly)
+    unsolved_vars = unsolved_vars.union(eqn.dim_expr.get_vars())
+    unsolved_polys.append(eqn.dim_expr)
   unsolved_vars = unsolved_vars.difference(shapeenv.keys())
-  eqns_str = "\n  ".join([str(eqn.poly) for eqn in eqns])
+  eqns_str = "\n  ".join([str(eqn.dim_expr) for eqn in eqns])
   err_msg = (
       f"Cannot solve for values of dimension variables {unsolved_vars} from "
       f"the remaining dimension polynomials\n  {eqns_str}.{_shapeenv_to_str()} "
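
As a worked illustration of the divmod-based solving above, take the spec "2*b1, 4*b2, b1+b2+18" and an actual argument of shape (16, 24, 32), the same numbers used by the test below. This is a sketch of the arithmetic the solver performs, not code from the change:

# From 2*b1 == args[0].shape[0]: the division must be even, else ValueError.
b1, r1 = divmod(16, 2)
assert (b1, r1) == (8, 0)
# From 4*b2 == args[0].shape[1]:
b2, r2 = divmod(24, 4)
assert (b2, r2) == (6, 0)
# The last equation has no unsolved variables left, so the solver only checks
# consistency: b1 + b2 + 18 minus the actual dimension must be 0.
assert b1 + b2 + 18 - 32 == 0
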
@@ -739,7 +739,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     arg_dtypes = (_f32,) * len(arg_shapes)
     def f_tf(*args_tf):
       avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
-      dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals)
+      def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
+        return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
+      dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals, get_dimension_size)
       dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(get_dim_values_jax,
                                                        args_tf, avals, "")
       if expected_avals is not None:
@@ -1007,16 +1009,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                     arg_descriptors=[RandArg((3, 4), _f32)],
                     poly_axes=[(0, 1)])

-  def test_non_trivial_polynomials_spec(self):
-    if config.jax_dynamic_shapes:
-      raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
+  @parameterized.named_parameters(
+      dict(testcase_name=f"_{str(polymorphic_shapes)}",
+           polymorphic_shapes=polymorphic_shapes)
+      # The polymorphic_shapes should have three comma-separated DimExpr matching
+      # 16, 24, 32
+      for polymorphic_shapes in [
+          "b1+6,b1+14,b2",  # b1=10, b2=32
+          "2*b1,4*b2,b1+b2+18",  # b1=8,b2=6
+          "b1+2*b2,4*b2,b1*b1+16",  # b1=4,b2=6
+      ])
+  def test_non_trivial_polynomials_spec(self,
+                                        polymorphic_shapes="2*b1,4*b2,b1+b2+18"):
     # We can handle non-trivial polynomials in the input shape,
-    # as long as all variables also occur in trivial polynoamials
+    # as long as all variables also occur in trivial expressions
     check_shape_poly(self,
-                     lambda x, y: x + y.reshape((-1,)),
-                     arg_descriptors=[RandArg((9,), _f32), RandArg((3, 3), _f32)],
-                     input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
-                     polymorphic_shapes=["b * b", "b, b"])
+                     lambda x: 2 * x.shape[0] + 3 * x.shape[1] + 4 * x.shape[2],
+                     arg_descriptors=[RandArg((16, 24, 32), _f32)],
+                     input_signature=[tf.TensorSpec([None, None, None])],
+                     polymorphic_shapes=polymorphic_shapes)

   def test_unused_args(self):
     # Tests with functions that do not use their inputs.
@@ -1166,9 +1177,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))

   def test_grad_not_var_output(self):
     # Output of the function has poly shapes, non-variable
-    if config.jax2tf_default_native_serialization:
-      raise unittest.SkipTest("Not supported with native serialization")
     def f_jax(x):  # :[b, 3]
       return jnp.reshape(x, (-1,))  # : [3b]
     x = np.arange(12, dtype=np.float32).reshape((4, 3))
@@ -1182,7 +1190,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     grad_tf = tape.gradient(res_tf, xv)
     self.assertAllClose(np.ones(x.shape, dtype=np.float32), grad_tf.numpy())

-
   def test_cond(self):
     # Test the primitive under conditional
     def f(x, y):
@@ -1280,9 +1287,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     self.assertAllClose(f_jax(y), restored_f(y))

   def test_saved_model_int_function(self):
-    # TODO(https://github.com/google/jax/issues/14437)
-    if config.jax2tf_default_native_serialization:
-      raise unittest.SkipTest("Gradient function does not use the dimension variables")

     def f_jax(x):  # x:s32[b, 3, 4]
       return jnp.reshape(x, (-1,))  # : s32[b * 12]
@@ -1301,10 +1305,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     self.assertAllClose(f_jax(x), res_jax_rt)

   def test_saved_model_constant_gradient(self):
-    # TODO(https://github.com/google/jax/issues/14437)
-    if config.jax2tf_default_native_serialization:
-      raise unittest.SkipTest("Gradient function does not use the dimension variables")
-
     def f_jax(x):  # A function whose gradient is a constant
       return x

@@ -1369,10 +1369,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
                    polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))

-    if not config.jax2tf_default_native_serialization:
-      # Does not support 2*b constraints
-      jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
-                     polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
+    jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
+                   polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))

     with self.assertRaisesRegex(
         core.InconclusiveDimensionOperation,
@@ -1398,7 +1396,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
                    native_serialization=False)(x0)

-    # TODO(https://github.com/google/jax/issues/14437)
     # In native serialization, or if we trace to a TF graph, we miss this
     res1_tf = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
                              native_serialization=True)(x0)
@@ -1427,7 +1424,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
                    native_serialization=False)(x45)

-    # TODO(https://github.com/google/jax/issues/14437)
     # In native serialization, or if we trace to a TF graph, we miss this
     res2_tf = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
                              native_serialization=True)(x45)
@@ -1443,13 +1439,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "Cannot solve for values of dimension variables"):
-      jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"],
-                     native_serialization=False)(x)
-    with self.assertRaisesRegex(
-        ValueError,
-        "dimension variables cannot be computed from the static shapes of the array arguments"):
-      jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"],
-                     native_serialization=True)(x)
+      jax2tf.convert(lambda x: jnp.sum(x), polymorphic_shapes=["a + b"])(x)

   def test_dynamic_shapes(self):
     # Test dim_as_value with dynamic shapes.
@@ -1470,11 +1461,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                                polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
     self.assertAllClose((9., 1.8), (res_primal, res_tangent))

-    # TODO(https://github.com/google/jax/issues/14437)
-    if not config.jax2tf_default_native_serialization:
-      self.assertAllClose(
-          np.array([3., 3., 3.]),
-          jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
+    self.assertAllClose(
+        np.array([3., 3., 3.]),
+        jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))

     xv = np.arange(24.).reshape((2, 3, 4))
     res_vmap = jax.vmap(f, in_axes=1)(xv)
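
The tests above show the user-visible effect of the change; the following usage sketch summarizes it (it requires TensorFlow and jax2tf, and the input arrays are illustrative, not taken verbatim from the tests):

import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

# A non-trivial polynomial spec is accepted as long as its variable can be
# solved from the input shapes:
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
               polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))

# A spec like "a + b" leaves its variables underdetermined, so conversion
# fails with "Cannot solve for values of dimension variables":
try:
  jax2tf.convert(lambda x: jnp.sum(x),
                 polymorphic_shapes=["a + b"])(np.ones((4,), dtype=np.float32))
except ValueError as e:
  print(e)
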
@@ -16,7 +16,6 @@ from jax._src.interpreters.mlir import (
     AxisContext as AxisContext,
     ConstantHandler as ConstantHandler,
     DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
-    DimExprEvaluator as DimExprEvaluator,
     LoweringResult as LoweringResult,
     LoweringRule as LoweringRule,
     LoweringRuleContext as LoweringRuleContext,