Merge pull request #15340 from gnecula:dim_vars3

PiperOrigin-RevId: 521424534
jax authors 2023-04-03 04:44:16 -07:00
commit 0d32724882
6 changed files with 135 additions and 192 deletions

View File

@@ -52,7 +52,7 @@ from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
from jax._src.typing import DimSize, OpaqueDType, Shape
from jax._src.typing import Array, DimSize, OpaqueDType, Shape
from jax._src import typing
traceback_util.register_exclusion(__file__)
@@ -2064,6 +2064,37 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
*dim_values: Array) -> Sequence[Array]:
"""Evaluates a shape possibly containing non-constants.
Args:
shape: the shape to evaluate.
dim_vars: the names of the dimension variables that may appear in `shape`.
dim_values: the dimension values corresponding to `dim_vars`.
Returns:
a tuple of JAX values corresponding to `shape`, of type
`dim_value_dtype`.
"""
env = dict(zip(dim_vars, dim_values))
def eval_one_dim(d: DimSize):
try:
return operator.index(d)
except:
# Is a _DimExpr
return d.evaluate(env) # type: ignore
return tuple(eval_one_dim(d) for d in shape)
def dim_value_dtype():
"""The dtype to be used for dimension values."""
return dtypes.canonicalize_dtype(np.int64)
def dim_constant(ct: int):
return np.array(ct, dtype=dim_value_dtype())
def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True)
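
The block above gives dimension values a single home in core: `evaluate_shape` substitutes concrete values for dimension variables, and `dim_value_dtype`/`dim_constant`/`dim_value_aval` fix their dtype (int32 by default, int64 under `jax_enable_x64`). A minimal sketch of how they compose, leaning on the private `shape_poly._parse_spec` helper purely as one way to obtain a shape containing a `_DimExpr` (an assumption for illustration):

    import numpy as np
    from jax._src import core
    from jax.experimental.jax2tf import shape_poly

    # Build a shape containing the dimension variable "b" (private helper,
    # used here only for illustration).
    sym_shape = shape_poly._parse_spec("(b, 4)", (None, 4))

    print(core.dim_value_dtype())   # int32 by default, int64 under x64

    # Symbolic dims are evaluated against the environment; constants pass
    # through operator.index unchanged.
    print(core.evaluate_shape(sym_shape, ("b",), core.dim_constant(3)))
    # roughly: (array(3, dtype=int32), 4)
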
# ------------------- Named shapes -------------------

View File

@@ -21,7 +21,6 @@ import functools
from functools import partial
import io
import itertools
import operator
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
@@ -572,71 +571,20 @@ def sharded_aval(aval: core.ShapedArray,
return aval.update(tuple(sharded_shape))
class DimExprValueMlir:
# TODO(necula): remove this, use regular JAX lowering
# A wrapper for an ir.Value that overloads + and * to be used for evaluating
# symbolic dimensions resulting in ir.Value. See shape_poly.DimExprValue.
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
def __init__(self, value: ir.Value):
self.value = value
def __add__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
if not isinstance(other, DimExprValueMlir):
other = DimExprValueMlir(ir_constant(other))
return DimExprValueMlir(hlo.AddOp(self.value, other.value).result)
def __radd__(self, other: Union[np.int32, np.int64]):
return DimExprValueMlir(ir_constant(other)).__add__(self)
def __mul__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
if not isinstance(other, DimExprValueMlir):
other = DimExprValueMlir(ir_constant(other))
return DimExprValueMlir(hlo.MulOp(self.value, other.value).result)
def __rmul__(self, other: Union[np.int32, np.int64]):
return DimExprValueMlir(ir_constant(other)).__mul__(self)
def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprValueMlir]):
if not isinstance(divisor, DimExprValueMlir):
divisor = DimExprValueMlir(ir_constant(divisor))
# Quotient
raw_quotient = hlo.DivOp(self.value, divisor.value)
raw_remainder = hlo.RemOp(self.value, divisor.value)
ops_different_sign = compare_hlo(hlo.SignOp(self.value),
hlo.SignOp(divisor.value),
"NE", "SIGNED")
rem_ne_zero = compare_hlo(raw_remainder, ir_constant(np.int64(0)),
"NE", "SIGNED")
must_adjust = hlo.AndOp(ops_different_sign, rem_ne_zero)
quotient = hlo.SelectOp(must_adjust,
hlo.SubtractOp(raw_quotient, ir_constant(np.int64(1))),
raw_quotient)
# Remainder
remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
return (DimExprValueMlir(quotient.result),
DimExprValueMlir(remainder.result))
def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
return DimExprValueMlir(ir_constant(dividend)).__divmod__(self)
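
The deleted `__divmod__` above hand-builds floor-division semantics on top of HLO's round-toward-zero `DivOp`/`RemOp`: when the operands' signs differ and the raw remainder is nonzero, the quotient is decremented by one and the remainder recomputed from the adjusted quotient. A plain-NumPy sketch of the same adjustment (illustrative; the original emits this logic as hlo ops):

    import numpy as np

    def floor_divmod(x, y):
        # DivOp/RemOp truncate toward zero; Python's divmod floors.
        raw_q = np.trunc(x / y).astype(np.int64)     # stands in for hlo.DivOp
        raw_r = x - raw_q * y                        # stands in for hlo.RemOp
        must_adjust = (np.sign(x) != np.sign(y)) & (raw_r != 0)
        q = np.where(must_adjust, raw_q - 1, raw_q)  # hlo.SelectOp
        r = x - y * q                                # remainder from adjusted quotient
        return q, r

    assert floor_divmod(np.int64(-7), np.int64(2)) == divmod(-7, 2)  # (-4, 1)
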
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
# assert not core.is_constant_shape(shape)
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
dim_var_env = {dv_name: DimExprValueMlir(dv_val[0])
for dv_name, dv_val in zip(ctx.module_context.dim_vars, ctx.dim_var_values)}
def eval_dim(d: core.DimSize) -> Union[int, ir.Value]:
try:
return operator.index(d)
except:
if isinstance(d, ir.Value):
return d
else:
# Is a dimension polynomial
return d.evaluate(dim_var_env).value # type: ignore
return tuple(eval_dim(d) for d in shape)
ctx = ctx.replace(
primitive="eval_dynamic_shape",
avals_in=[core.dim_value_aval()] * len(ctx.module_context.dim_vars))
res = lower_fun(
partial(core.evaluate_shape, shape, ctx.module_context.dim_vars),
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore
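
With `DimExprValueMlir` gone, `eval_dynamic_shape` reuses the regular lowering machinery: `core.evaluate_shape` is an ordinary JAX function, so `lower_fun` traces it and emits the same add/mul HLO the wrapper used to spell out by hand. A hedged sketch of why that works, again using the private `shape_poly._parse_spec` helper for illustration:

    import jax
    from functools import partial
    from jax._src import core
    from jax.experimental.jax2tf import shape_poly

    b = shape_poly._parse_spec("(b,)", (None,))[0]   # a _DimExpr

    # evaluate_shape is plain JAX, so it can be traced like any function;
    # this is what lower_fun does with it during lowering.
    print(jax.make_jaxpr(partial(core.evaluate_shape, (2 * b + 1,), ("b",)))(
        core.dim_constant(3)))
    # roughly: { lambda ; a:i32[]. let b:i32[] = mul 2 a
    #                                  c:i32[] = add b 1 in (c,) }
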
class LoweringResult(NamedTuple):
module: ir.Module

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""
import dataclasses
from functools import partial
import contextlib
import operator
@@ -55,7 +54,6 @@ from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
@@ -382,7 +380,7 @@ def convert(fun_jax: Callable,
args_flat_tf, args_avals_flat = util.unzip2(args_and_avals)
if native_serialization:
shape_env = ()
shape_env: Sequence[Tuple[str, TfVal]] = ()
if native_serialization_platforms:
lowering_platform = native_serialization_platforms[0]
else:
@@ -399,14 +397,11 @@ def convert(fun_jax: Callable,
native_serialization_strict_checks)
return outs_tf, out_avals
else:
def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(
args_avals_flat, get_dimension_size)
dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values) # type: ignore
dim_vars = shape_poly.all_dim_vars(args_avals_flat)
dim_values, _ = _interpret_fun_jax(
partial(shape_poly.compute_dim_values, args_avals_flat, dim_vars),
args_flat_tf, args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values)
exported = None
def run_fun_flat_as_tf(
args_flat_tf: Sequence[TfVal]
@@ -984,10 +979,9 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV
return tuple(int(d) for d in shape)
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
eval_shape_jax = shape_poly.get_shape_evaluator(dim_vars, shape)
dim_aval = shape_poly.dim_as_value_abstract()
shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
dim_values, [dim_aval] * len(dim_values), "") # type: ignore
shape_values_tf, _ = _interpret_fun_jax(
partial(core.evaluate_shape, shape, dim_vars),
dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore
# Keep only the non-constant dimensions
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf
for d, d_tf in zip(shape, shape_values_tf))
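
The tail of `_eval_shape` keeps dimensions that are already constants as Python ints, so TF can constant-fold them, and takes the traced value only for genuinely symbolic dimensions. A toy sketch of that filter, under the assumption that `shape_values` holds one evaluated value per dimension:

    import operator
    from jax._src import core

    def keep_constants(shape, shape_values):
        # Constant dims stay Python ints; only symbolic dims use the
        # traced/evaluated value.
        return tuple(operator.index(d) if core.is_constant_dim(d) else v
                     for d, v in zip(shape, shape_values))
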

View File

@@ -16,8 +16,10 @@
This module is used with jax2tf, but should have no TensorFlow dependencies.
"""
import dataclasses
import functools
import itertools
import re
from typing import Callable, List, Optional, Sequence, Union
from absl import logging
@@ -25,6 +27,7 @@ import jax
from jax import sharding
from jax._src import core
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
@@ -34,6 +37,7 @@ from jax._src.lib.mlir.dialects import stablehlo
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax.experimental.jax2tf import shape_poly
map = util.safe_map
@@ -283,7 +287,14 @@ def add_dim_arg_computation(module: mlir.ir.Module,
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[mlir.ir.Value] = []
dim_args = compute_dim_args(args_avals, tuple(new_main_op.arguments),
module_context = mlir.ModuleContext(
"cpu", "cpu", mlir.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=new_module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
dim_args = compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
orig_input_types[:len(dim_vars)])
# The first arguments are the dimension variables
orig_main_args.extend(dim_args)
@@ -298,6 +309,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
def compute_dim_args(
ctx: mlir.LoweringRuleContext,
args_avals: Sequence[core.ShapedArray],
array_args: Sequence[mlir.ir.Value],
dim_arg_types: Sequence[mlir.ir.Type]) -> Sequence[mlir.ir.Value]:
@@ -312,20 +324,12 @@ def compute_dim_args(
the values of the dimension variables, in the sorted order of the
dimension variables.
"""
def get_dimension_size(args, arg_idx: int, dim_idx: int) -> mlir.DimExprValueMlir:
dim_size = hlo.GetDimensionSizeOp(args[arg_idx], dim_idx).result
dim_type = mlir.aval_to_ir_type(shape_poly.dim_as_value_abstract())
if dim_size.type != dim_type:
dim_size = hlo.ConvertOp(dim_type, dim_size).result
return mlir.DimExprValueMlir(dim_size)
all_dim_vars, dim_arg_builders = shape_poly.prepare_dim_var_env(
args_avals, get_dimension_size)
all_dim_args: Sequence[mlir.DimExprValueMlir] = dim_arg_builders(*array_args)
dim_vars = shape_poly.all_dim_vars(args_avals)
dim_values = mlir.lower_fun(
functools.partial(shape_poly.compute_dim_values, args_avals, dim_vars),
multiple_results=True)(ctx, *array_args)
res = []
for dim_arg, dim_arg_type in zip(all_dim_args, dim_arg_types):
dim_arg = dim_arg.value
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
if dim_arg.type != dim_arg_type:
res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
else:

View File

@@ -32,12 +32,12 @@ jax2tf.convert docstring, and the
"""
import collections
import dataclasses
import itertools
import functools
import itertools
import math
import operator as op
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
import opt_einsum
@@ -50,17 +50,11 @@ from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src.lax import lax
from jax._src.typing import DimSize, Shape
TfVal = Any
# A dimension environment maps dimension variables to expressions that
# denote the values of the dimension (DimExprValue). A DimExprValue must support addition
# and multiplication with other expressions or with int32/int64, e.g.,
# by overloading __add__, __radd__, __mul__, __rmul__, __divmod__, __rdivmod__.
DimExprValue = Any
ShapeEnv = Dict[str, DimExprValue]
DimVarEnv = Dict[str, jax.Array]
DType = Any
class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
@@ -208,7 +202,7 @@ class _DimAtom:
else:
assert False
def evaluate(self, env: ShapeEnv):
def evaluate(self, env: DimVarEnv):
if self.var is not None:
try:
return env[self.var]
@@ -314,8 +308,8 @@ class _DimMon(dict):
return (min(*candidates), max(*candidates)) # type: ignore
def evaluate(self, env: ShapeEnv):
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else dim_constant(1)
def evaluate(self, env: DimVarEnv):
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
def pow_opt(v, p: int):
return v if p == 1 else prod([v] * p)
return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()])
@@ -641,14 +635,14 @@ class _DimExpr():
"""Returns the highest degree term that comes first lexicographically."""
return max(self.monomials())
def evaluate(self, env: ShapeEnv):
# Evaluates as a value of dtype=dim_as_value_dtype()
terms = [_evaluate_multiply(mon.evaluate(env), dim_constant(coeff)) for mon, coeff in self.monomials()]
def evaluate(self, env: DimVarEnv):
# Evaluates as a value of dtype=core.dim_value_dtype()
terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) for mon, coeff in self.monomials()]
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
@staticmethod
def get_aval(dim: "_DimExpr"):
return dim_as_value_abstract()
return core.dim_value_aval()
def dimension_as_value(self):
"""Turns a dimension size into a Jax value that we can compute with."""
@@ -779,18 +773,9 @@ lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimExpr. The value of the primitive is the value of the dimension,
# using int64 in x64 mode or int32 otherwise (dim_as_value_dtype())
# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
dim_as_value_p = core.Primitive("dim_as_value")
def dim_as_value_dtype():
return dtypes.canonicalize_dtype(np.int64)
def dim_constant(ct: int):
return np.array(ct, dtype=dim_as_value_dtype())
def dim_as_value_abstract() -> core.AbstractValue:
return core.ShapedArray((), dim_as_value_dtype(), weak_type=True)
dim_as_value_p.def_abstract_eval(lambda dim: dim_as_value_abstract())
dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval())
def dim_as_value_impl(dim: DimSize):
raise NotImplementedError(
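
`dim_as_value` is meaningful only under tracing (its eager impl raises, as above): it fires whenever a symbolic dimension is used as an ordinary value. A hedged end-to-end example through `jax2tf.convert`, which is where the primitive is exercised:

    import jax.numpy as jnp
    from jax.experimental import jax2tf

    def average_rows(x):
      # x.shape[0] is the symbolic dimension "b"; using it in arithmetic
      # goes through dim_as_value, yielding a scalar of core.dim_value_dtype().
      return jnp.sum(x, axis=0) / x.shape[0]

    # f_tf accepts any leading dimension, e.g. tf.zeros((7, 4)).
    f_tf = jax2tf.convert(average_rows, polymorphic_shapes=["(b, 4)"])
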
@@ -992,44 +977,24 @@ def _is_known_constant(v) -> Optional[int]:
# dimension_size(operand, dimension=i) gets the operand.shape[i] as a
# value of type shape_poly.dim_as_value_dtype().
dimension_size_p = core.Primitive("dimension_size")
def _dimension_size_abstract(aval: core.AbstractValue, **_) -> core.AbstractValue:
return dim_as_value_abstract()
def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue:
return core.dim_value_aval()
dimension_size_p.def_abstract_eval(_dimension_size_abstract)
dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval)
def _dimension_size_impl(arg, *, dimension):
return dim_constant(arg.shape[dimension])
return core.dim_constant(arg.shape[dimension])
dimension_size_p.def_impl(_dimension_size_impl)
_JaxValue = Any
def _dimension_size_lowering_rule(ctx, arg, *, dimension):
dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension)
dim_type = mlir.aval_to_ir_type(core.dim_value_aval())
if dim_size.result.type != dim_type:
dim_size = mlir.hlo.ConvertOp(dim_type, dim_size)
return dim_size.results
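
`dimension_size_p` now has both an eager impl (returning `core.dim_constant(arg.shape[dimension])`) and the StableHLO lowering above via `GetDimensionSizeOp`, so the same primitive backs equation solving both eagerly and under lowering. A small eager sketch:

    import numpy as np
    from jax.experimental.jax2tf import shape_poly

    x = np.zeros((3, 5), np.float32)
    # Outside a trace, bind falls through to the eager impl.
    print(shape_poly.dimension_size_p.bind(x, dimension=1))
    # roughly: 5, as a scalar of core.dim_value_dtype()
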
@dataclasses.dataclass
class DimEquation:
# Represents args[arg_idx].shape[dim_idx] == dim_expr
arg_idx: int
dim_idx: int
dim_expr: _DimExpr
mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)
def __str__(self):
return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"
def get_shape_evaluator(dim_vars: Sequence[str], shape: Sequence[DimSize]) ->\
Callable[..., TfVal]:
"""Prepares a shape evaluator.
Returns a JAX function that given the values for the dimension variables
returns the values for the dimensions of `shape`.
"""
def eval_shape(*dim_values: Any) -> Sequence[Any]:
shape_env_jax = dict(zip(dim_vars, dim_values))
def eval_dim(d: DimSize):
return d.evaluate(shape_env_jax) # type: ignore[union-attr]
return tuple(eval_dim(d) if type(d) is _DimExpr else np.array(d, dtype=dim_as_value_dtype()) # type: ignore
for d in shape)
return eval_shape
def arg_aval(
arg_shape: Sequence[Optional[int]],
@@ -1046,53 +1011,55 @@ def arg_aval(
aval_shape = _parse_spec(polymorphic_shape, arg_shape)
return core.ShapedArray(aval_shape, arg_jax_dtype)
def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Set[str]:
def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
dim_vars: Set[str] = set()
for a in args_avals:
for d in a.shape:
if is_poly_dim(d):
dim_vars = dim_vars.union(d.get_vars())
return dim_vars
return sorted(tuple(dim_vars))
def prepare_dim_var_env(
@dataclasses.dataclass
class DimEquation:
# Represents arg.shape[dim_idx] == dim_expr
arg_idx: int
arg: jax.Array
dim_idx: int
dim_expr: _DimExpr
def __str__(self):
return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"
def compute_dim_values(
args_avals: Sequence[core.AbstractValue],
get_dimension_size: Callable[[Sequence[Any], int, int], DimExprValue]) -> \
Tuple[Sequence[str],
Callable[..., Sequence[DimExprValue]]]:
"""Get the dimension variables and the function to compute them.
dim_vars: Sequence[str],
*args: jax.Array) -> Sequence[jax.Array]:
"""Compute values of dimension variables from the actual arguments.
Args:
args_aval: the abstract values of the array arguments
get_dimension_size: a function that given the array arguments, the argument
index, and the dimension index, computes the value of args[arg_idx].shape[dim_idx].
args_avals: the abstract values of the `args`, with shapes that may
include dimension variables.
dim_vars: the dimension variables
args: the actual arguments
Returns: a tuple of dimension variables that appear in `args_avals` along
with a function that given the actual arguments of the top-level function
returns a tuple of dimension variable values, in the same order as the
dimension variables returned in the first component.
Returns: the values of `dim_vars`.
"""
# TODO(necula): clean up the notion of shape environments, and solving for dimension
# variables.
dim_vars: Set[str] = all_dim_vars(args_avals)
dim_vars_sorted = sorted(tuple(dim_vars))
def get_dim_var_values(*args: Any) -> Sequence[Any]:
dim_equations: List[DimEquation] = []
for arg_idx, a in enumerate(args_avals):
for dim_idx, d in enumerate(a.shape):
if is_poly_dim(d):
dim_equations.append(
DimEquation(arg_idx=arg_idx, dim_idx=dim_idx, dim_expr=d))
dim_env = _solve_dim_equations(dim_equations,
functools.partial(get_dimension_size, args))
return tuple(dim_env[dv] for dv in dim_vars_sorted)
return dim_vars_sorted, get_dim_var_values
dim_equations: List[DimEquation] = []
for arg_idx, a in enumerate(args_avals):
for dim_idx, d in enumerate(a.shape):
if is_poly_dim(d):
dim_equations.append(
DimEquation(arg_idx=arg_idx, arg=args[arg_idx],
dim_idx=dim_idx, dim_expr=d))
dim_env = _solve_dim_equations(dim_equations)
dim_values = tuple(dim_env[dv] for dv in dim_vars)
return dim_values
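
End to end, `all_dim_vars` plus `compute_dim_values` replace the old `prepare_dim_var_env` two-step. A hedged eager example using the signatures as they appear in this diff:

    import numpy as np
    from jax.experimental.jax2tf import shape_poly

    aval = shape_poly.arg_aval((None, 4), np.float32, "(b, 4)")
    dim_vars = shape_poly.all_dim_vars([aval])       # -> ['b'], sorted
    x = np.zeros((3, 4), np.float32)
    # Solves b == x.shape[0] from the dimension equations.
    print(shape_poly.compute_dim_values([aval], dim_vars, x))
    # roughly: (array(3, dtype=int32),)
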
def _solve_dim_equations(eqns: List[DimEquation],
get_dimension_size: Callable[[int, int], DimExprValue]) -> ShapeEnv:
def _solve_dim_equations(eqns: List[DimEquation]) -> DimVarEnv:
# Returns a shape environment if it can solve all dimension variables.
# Raises an exception if it cannot.
shapeenv: ShapeEnv = {}
shapeenv: DimVarEnv = {}
def _shapeenv_to_str() -> str:
if shapeenv:
@@ -1107,14 +1074,14 @@ def _solve_dim_equations(eqns: List[DimEquation],
# Otherwise, add the variable to shapeenv and return True.
var, factor_var = None, None
dim_value = get_dimension_size(eqn.arg_idx, eqn.dim_idx)
dim_value = dimension_size_p.bind(eqn.arg, dimension=eqn.dim_idx)
# The invariant is: var * factor_var + rest_eqn_dim_expr = dim_value
for mon, factor in eqn.dim_expr.monomials():
# Perhaps we can already evaluate this monomial (all vars solved)
try:
mon_value = mon.evaluate(shapeenv)
dim_value = dim_value + -1 * _evaluate_multiply(mon_value, dim_constant(factor))
dim_value = dim_value + -1 * _evaluate_multiply(mon_value, core.dim_constant(factor))
continue
except KeyError:
# There are some indeterminate variables. We handle only the case of
@@ -1128,9 +1095,9 @@ def _solve_dim_equations(eqns: List[DimEquation],
if var is not None:
if factor_var == 1:
var_value, var_remainder = dim_value, dim_constant(0)
var_value, var_remainder = dim_value, core.dim_constant(0)
else:
var_value, var_remainder = divmod(dim_value, dim_constant(factor_var)) # type: ignore
var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var)) # type: ignore
# Check that the division is even. Works only in eager mode.
var_remainder_int = _is_known_constant(var_remainder)
@@ -1148,7 +1115,7 @@ def _solve_dim_equations(eqns: List[DimEquation],
f"{eqn}.{_shapeenv_to_str()}")
raise ValueError(msg)
shapeenv[var] = var_value
shapeenv[var] = var_value.astype(core.dim_value_dtype())
return True
else:
# All variables are resolved for this equation
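
A worked instance of the solving loop above, with illustrative numbers: suppose an argument is declared as `(2*b + 1,)` and arrives with concrete shape `(7,)`.

    # dim_value = dimension_size_p.bind(arg, dimension=0), here 7
    dim_value = 7
    dim_value = dim_value + -1 * 1                   # move the constant monomial across
    var_value, var_remainder = divmod(dim_value, 2)  # divide by factor_var == 2
    assert var_remainder == 0   # an uneven division raises ValueError (eager mode)
    # shapeenv["b"] = var_value == 3
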

View File

@@ -797,11 +797,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
arg_dtypes = (_f32,) * len(arg_shapes)
def f_tf(*args_tf):
avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals, get_dimension_size)
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(get_dim_values_jax,
args_tf, avals, "")
dim_vars = shape_poly.all_dim_vars(avals)
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
partial(shape_poly.compute_dim_values, avals, dim_vars),
args_tf, avals, "")
if expected_avals is not None:
self.assertEqual(expected_avals, avals)
return dict(zip(dim_vars, dim_values))