Merge pull request #15340 from gnecula:dim_vars3
PiperOrigin-RevId: 521424534
commit 0d32724882
@@ -52,7 +52,7 @@ from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
 import jax._src.pretty_printer as pp
 from jax._src.lib import jax_jit
 from jax._src import traceback_util
-from jax._src.typing import DimSize, OpaqueDType, Shape
+from jax._src.typing import Array, DimSize, OpaqueDType, Shape
 from jax._src import typing
 traceback_util.register_exclusion(__file__)
 
@@ -2064,6 +2064,37 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
   return TypeError(msg)
 
+def evaluate_shape(shape: Shape, dim_vars: Sequence[str],
+                   *dim_values: Array) -> Sequence[Array]:
+  """Evaluates a shape possibly containing non-constants.
+
+  Args:
+    shape: the shape to evaluate.
+    dim_vars: the dimension variables names that may appear in `shape`.
+    dim_values: the dimension values corresponding to `dim_vars`.
+
+  Returns:
+    a tuple of JAX values corresponding to `shape`, of type
+    `dim_value_dtype`.
+  """
+  env = dict(zip(dim_vars, dim_values))
+  def eval_one_dim(d: DimSize):
+    try:
+      return operator.index(d)
+    except:
+      # Is a _DimExpr
+      return d.evaluate(env)  # type: ignore
+  return tuple(eval_one_dim(d) for d in shape)
+
+def dim_value_dtype():
+  """The dtype to be used for dimension values."""
+  return dtypes.canonicalize_dtype(np.int64)
+
+def dim_constant(ct: int):
+  return np.array(ct, dtype=dim_value_dtype())
+
+def dim_value_aval() -> AbstractValue:
+  return ShapedArray((), dim_value_dtype(), weak_type=True)
+
 # ------------------- Named shapes -------------------
 
 
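Note: a quick illustration of how the new core.evaluate_shape resolves a mixed constant/symbolic shape. This is a minimal sketch; FakeDimExpr is a hypothetical stand-in for the real symbolic dimensions (shape_poly._DimExpr), used only to show the dispatch between operator.index and d.evaluate(env):

    import operator
    import numpy as np

    class FakeDimExpr:  # hypothetical stand-in for a symbolic dimension
        def __init__(self, var): self.var = var
        def evaluate(self, env): return env[self.var]

    def evaluate_shape(shape, dim_vars, *dim_values):
        env = dict(zip(dim_vars, dim_values))
        def eval_one_dim(d):
            try:
                return operator.index(d)   # constant dimension
            except TypeError:
                return d.evaluate(env)     # symbolic dimension
        return tuple(eval_one_dim(d) for d in shape)

    # The symbolic "b" resolves to 5: returns (2, 5)
    print(evaluate_shape((2, FakeDimExpr("b")), ["b"], np.int32(5)))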
@@ -21,7 +21,6 @@ import functools
 from functools import partial
 import io
 import itertools
 import operator
 import re
 import typing
 from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
@@ -572,71 +571,20 @@ def sharded_aval(aval: core.ShapedArray,
   return aval.update(tuple(sharded_shape))
 
 
-class DimExprValueMlir:
-  # TODO(necula): remove this, use regular JAX lowering
-  # A wrapper for an ir.Value that overloads + and * to be used for evaluating
-  # symbolic dimensions resulting in ir.Value. See shape_poly.DimExprValue.
-  __array_priority__ = 1000  # Same as tracer, for __radd__ and others on ndarray
-  def __init__(self, value: ir.Value):
-    self.value = value
-
-  def __add__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
-    if not isinstance(other, DimExprValueMlir):
-      other = DimExprValueMlir(ir_constant(other))
-    return DimExprValueMlir(hlo.AddOp(self.value, other.value).result)
-
-  def __radd__(self, other: Union[np.int32, np.int64]):
-    return DimExprValueMlir(ir_constant(other)).__add__(self)
-
-  def __mul__(self, other: Union[np.int32, np.int64, DimExprValueMlir]):
-    if not isinstance(other, DimExprValueMlir):
-      other = DimExprValueMlir(ir_constant(other))
-    return DimExprValueMlir(hlo.MulOp(self.value, other.value).result)
-
-  def __rmul__(self, other: Union[np.int32, np.int64]):
-    return DimExprValueMlir(ir_constant(other)).__mul__(self)
-
-  def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprValueMlir]):
-    if not isinstance(divisor, DimExprValueMlir):
-      divisor = DimExprValueMlir(ir_constant(divisor))
-    # Quotient
-    raw_quotient = hlo.DivOp(self.value, divisor.value)
-    raw_remainder = hlo.RemOp(self.value, divisor.value)
-    ops_different_sign = compare_hlo(hlo.SignOp(self.value),
-                                     hlo.SignOp(divisor.value),
-                                     "NE", "SIGNED")
-    rem_ne_zero = compare_hlo(raw_remainder, ir_constant(np.int64(0)),
-                              "NE", "SIGNED")
-    must_adjust = hlo.AndOp(ops_different_sign, rem_ne_zero)
-    quotient = hlo.SelectOp(must_adjust,
-                            hlo.SubtractOp(raw_quotient, ir_constant(np.int64(1))),
-                            raw_quotient)
-    # Remainder
-    remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
-    return (DimExprValueMlir(quotient.result),
-            DimExprValueMlir(remainder.result))
-
-  def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
-    return DimExprValueMlir(ir_constant(dividend)).__divmod__(self)
-
 def eval_dynamic_shape(ctx: LoweringRuleContext,
                        shape: core.Shape) -> Tuple[Union[int, Value], ...]:
   # assert not core.is_constant_shape(shape)
   if config.jax_dynamic_shapes:
     return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore
   else:
-    dim_var_env = {dv_name: DimExprValueMlir(dv_val[0])
-                   for dv_name, dv_val in zip(ctx.module_context.dim_vars, ctx.dim_var_values)}
-    def eval_dim(d: core.DimSize) -> Union[int, ir.Value]:
-      try:
-        return operator.index(d)
-      except:
-        if isinstance(d, ir.Value):
-          return d
-        else:
-          # Is a dimension polynomial
-          return d.evaluate(dim_var_env).value  # type: ignore
-    return tuple(eval_dim(d) for d in shape)
+    ctx = ctx.replace(
+        primitive="eval_dynamic_shape",
+        avals_in=[core.dim_value_aval()] * len(ctx.module_context.dim_vars))
+    res = lower_fun(
+        partial(core.evaluate_shape, shape, ctx.module_context.dim_vars),
+        multiple_results=True)(ctx, *ctx.dim_var_values)
+    return util.flatten(res)  # type: ignore
 
 
 class LoweringResult(NamedTuple):
   module: ir.Module
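Note on the deleted __divmod__ lowering: the HLO div/rem pair truncates toward zero, while the shape machinery needs Python's floor semantics, hence the sign-comparison and SelectOp adjustment above. A plain-Python rendering of the same adjustment (my sketch, not jax code):

    import math

    def floor_divmod(x: int, y: int):
        q = math.trunc(x / y)  # truncating quotient, as hlo.DivOp computes
        r = x - y * q          # truncating remainder, as hlo.RemOp computes
        if r != 0 and (x < 0) != (y < 0):  # signs differ and division inexact
            q -= 1             # the SelectOp adjustment toward floor
            r = x - y * q
        return q, r

    assert floor_divmod(-7, 2) == divmod(-7, 2) == (-4, 1)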
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Provides JAX and TensorFlow interoperation APIs."""
 import dataclasses
 from functools import partial
 import contextlib
 import operator
@@ -55,7 +54,6 @@ from jax._src import random as random_internal
 from jax._src import source_info_util
 from jax._src import util
 from jax._src.interpreters import ad
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lax import control_flow as lax_control_flow
 from jax._src.lax import lax as lax_internal
@@ -382,7 +380,7 @@ def convert(fun_jax: Callable,
     args_flat_tf, args_avals_flat = util.unzip2(args_and_avals)
 
     if native_serialization:
-      shape_env = ()
+      shape_env: Sequence[Tuple[str, TfVal]] = ()
       if native_serialization_platforms:
         lowering_platform = native_serialization_platforms[0]
       else:
@@ -399,14 +397,11 @@
                                    native_serialization_strict_checks)
       return outs_tf, out_avals
     else:
-      def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
-        return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
-      dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(
-          args_avals_flat, get_dimension_size)
-
-      dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
-                                         args_avals_flat, name_stack)
-      shape_env = zip(dim_vars, dim_values)  # type: ignore
+      dim_vars = shape_poly.all_dim_vars(args_avals_flat)
+      dim_values, _ = _interpret_fun_jax(
+          partial(shape_poly.compute_dim_values, args_avals_flat, dim_vars),
+          args_flat_tf, args_avals_flat, name_stack)
+      shape_env = zip(dim_vars, dim_values)
       exported = None
 
     def run_fun_flat_as_tf(
         args_flat_tf: Sequence[TfVal]
@@ -984,10 +979,9 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
     return tuple(int(d) for d in shape)
 
   dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
-  eval_shape_jax = shape_poly.get_shape_evaluator(dim_vars, shape)
-  dim_aval = shape_poly.dim_as_value_abstract()
-  shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
-      dim_values, [dim_aval] * len(dim_values), "")  # type: ignore
+  shape_values_tf, _ = _interpret_fun_jax(
+      partial(core.evaluate_shape, shape, dim_vars),
+      dim_values, [core.dim_value_aval()] * len(dim_values), "")  # type: ignore
   # Keep only the non-constant dimensions
   return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf
                for d, d_tf in zip(shape, shape_values_tf))
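Note: for context, this is the user-visible path that exercises the new shape_env computation. A minimal conversion with one polymorphic batch dimension (a hypothetical usage sketch; polymorphic_shapes is the standard jax2tf.convert parameter):

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # "b" is a dimension variable; its value is computed from the actual
    # argument shapes (via shape_poly.compute_dim_values) at call time.
    f_tf = jax2tf.convert(lambda x: jnp.sum(x, axis=0),
                          polymorphic_shapes=["(b, 4)"])
    print(f_tf(tf.ones((3, 4))))  # works for any batch size b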
@@ -16,8 +16,10 @@
 This module is used with jax2tf, but should have no TensorFlow dependencies.
 """
 import dataclasses
+import functools
+import itertools
 import re
 from typing import Callable, List, Optional, Sequence, Union
 
 from absl import logging
 
@@ -25,6 +27,7 @@ import jax
 from jax import sharding
 
 from jax._src import core
+from jax._src import source_info_util
 from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import mlir
@@ -34,6 +37,7 @@ from jax._src.lib.mlir.dialects import stablehlo
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
 
 from jax.experimental.jax2tf import shape_poly
 
 map = util.safe_map
@@ -283,7 +287,14 @@ def add_dim_arg_computation(module: mlir.ir.Module,
   entry_block = new_main_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
     orig_main_args: List[mlir.ir.Value] = []
-    dim_args = compute_dim_args(args_avals, tuple(new_main_op.arguments),
+    module_context = mlir.ModuleContext(
+        "cpu", "cpu", mlir.ShardingContext([]),
+        source_info_util.new_name_stack(),
+        [], itertools.count(1), [], module=new_module, context=context)
+    ctx = mlir.LoweringRuleContext(module_context=module_context,
+        primitive=None, avals_in=args_avals, avals_out=None,
+        tokens_in=mlir.TokenSet(), tokens_out=None)
+    dim_args = compute_dim_args(ctx, args_avals, tuple(new_main_op.arguments),
                                 orig_input_types[:len(dim_vars)])
     # The first arguments are the dimension variable
     orig_main_args.extend(dim_args)
@@ -298,6 +309,7 @@ def add_dim_arg_computation(module: mlir.ir.Module,
 
 
 def compute_dim_args(
+    ctx: mlir.LoweringRuleContext,
     args_avals: Sequence[core.ShapedArray],
     array_args: Sequence[mlir.ir.Value],
     dim_arg_types: Sequence[mlir.ir.Type]) -> Sequence[mlir.ir.Value]:
@@ -312,20 +324,12 @@ def compute_dim_args(
     the values of the dimension variables, in the sorted order of the
     dimension variables.
   """
-  def get_dimension_size(args, arg_idx: int, dim_idx: int) -> mlir.DimExprValueMlir:
-    dim_size = hlo.GetDimensionSizeOp(args[arg_idx], dim_idx).result
-    dim_type = mlir.aval_to_ir_type(shape_poly.dim_as_value_abstract())
-    if dim_size.type != dim_type:
-      dim_size = hlo.ConvertOp(dim_type, dim_size).result
-    return mlir.DimExprValueMlir(dim_size)
-
-  all_dim_vars, dim_arg_builders = shape_poly.prepare_dim_var_env(
-      args_avals, get_dimension_size)
-  all_dim_args: Sequence[mlir.DimExprValueMlir] = dim_arg_builders(*array_args)
-
+  dim_vars = shape_poly.all_dim_vars(args_avals)
+  dim_values = mlir.lower_fun(
+      functools.partial(shape_poly.compute_dim_values, args_avals, dim_vars),
+      multiple_results=True)(ctx, *array_args)
   res = []
-  for dim_arg, dim_arg_type in zip(all_dim_args, dim_arg_types):
-    dim_arg = dim_arg.value
+  for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_arg_types):
     if dim_arg.type != dim_arg_type:
       res.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
     else:
@@ -32,12 +32,12 @@ jax2tf.convert docstring, and the
 """
 import collections
 import dataclasses
-import itertools
 import functools
+import itertools
 import math
 import operator as op
 import re
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
 
 import numpy as np
 import opt_einsum
@@ -50,17 +50,11 @@ from jax._src import core
 from jax._src import dtypes
+from jax._src.interpreters import mlir
 from jax._src.numpy import lax_numpy
 from jax._src.lax import lax
 from jax._src.typing import DimSize, Shape
 
 
 TfVal = Any
-# A dimension environment maps dimension variables to expressions that
-# denote the values of the dimension (DimExprValue). A DimExprValue must support addition
-# and multiplication with other expressions or with int32/int64, e.g.,
-# by overloading __add__, __radd__, __mul__, __rmul__, __divmod__, __rdivmod__.
-DimExprValue = Any
-ShapeEnv = Dict[str, DimExprValue]
-
+DimVarEnv = Dict[str, jax.Array]
 DType = Any
 
 class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
@@ -208,7 +202,7 @@ class _DimAtom:
     else:
       assert False
 
-  def evaluate(self, env: ShapeEnv):
+  def evaluate(self, env: DimVarEnv):
     if self.var is not None:
       try:
         return env[self.var]
@@ -314,8 +308,8 @@ class _DimMon(dict):
     return (min(*candidates), max(*candidates))  # type: ignore
 
 
-  def evaluate(self, env: ShapeEnv):
-    prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else dim_constant(1)
+  def evaluate(self, env: DimVarEnv):
+    prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
     def pow_opt(v, p: int):
       return v if p == 1 else prod([v] * p)
     return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()])
@@ -641,14 +635,14 @@ class _DimExpr():
     """Returns the highest degree term that comes first lexicographically."""
     return max(self.monomials())
 
-  def evaluate(self, env: ShapeEnv):
-    # Evaluates as a value of dtype=dim_as_value_dtype()
-    terms = [_evaluate_multiply(mon.evaluate(env), dim_constant(coeff)) for mon, coeff in self.monomials()]
+  def evaluate(self, env: DimVarEnv):
+    # Evaluates as a value of dtype=core.dim_value_dtype()
+    terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)) for mon, coeff in self.monomials()]
     return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
 
   @staticmethod
   def get_aval(dim: "_DimExpr"):
-    return dim_as_value_abstract()
+    return core.dim_value_aval()
 
   def dimension_as_value(self):
     """Turns a dimension size into a Jax value that we can compute with."""
@@ -779,18 +773,9 @@ lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
 
 # A JAX primitive with no array arguments but with a dimension parameter
 # that is a DimExpr. The value of the primitive is the value of the dimension,
-# using int64 in x64 mode or int32 otherwise (dim_as_value_dtype())
+# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
 dim_as_value_p = core.Primitive("dim_as_value")
-def dim_as_value_dtype():
-  return dtypes.canonicalize_dtype(np.int64)
-
-def dim_constant(ct: int):
-  return np.array(ct, dtype=dim_as_value_dtype())
-
-def dim_as_value_abstract() -> core.AbstractValue:
-  return core.ShapedArray((), dim_as_value_dtype(), weak_type=True)
-
-dim_as_value_p.def_abstract_eval(lambda dim: dim_as_value_abstract())
+dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval())
 
 def dim_as_value_impl(dim: DimSize):
   raise NotImplementedError(
@@ -992,44 +977,24 @@ def _is_known_constant(v) -> Optional[int]:
 # dimension_size(operand, dimension=i) get the operand.shape[i] as a
 # value of type shape_poly.dim_as_value_dtype().
 dimension_size_p = core.Primitive("dimension_size")
-def _dimension_size_abstract(aval: core.AbstractValue, **_) -> core.AbstractValue:
-  return dim_as_value_abstract()
+def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue:
+  return core.dim_value_aval()
 
-dimension_size_p.def_abstract_eval(_dimension_size_abstract)
+dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval)
 
 def _dimension_size_impl(arg, *, dimension):
-  return dim_constant(arg.shape[dimension])
+  return core.dim_constant(arg.shape[dimension])
 dimension_size_p.def_impl(_dimension_size_impl)
 
-_JaxValue = Any
+def _dimension_size_lowering_rule(ctx, arg, *, dimension):
+  dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension)
+  dim_type = mlir.aval_to_ir_type(core.dim_value_aval())
+  if dim_size.result.type != dim_type:
+    dim_size = mlir.hlo.ConvertOp(dim_type, dim_size)
+  return dim_size.results
 
-@dataclasses.dataclass
-class DimEquation:
-  # Represents args[arg_idx].shape[dim_idx] == dim_expr
-  arg_idx: int
-  dim_idx: int
-  dim_expr: _DimExpr
+mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)
 
-  def __str__(self):
-    return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"
-
-
-def get_shape_evaluator(dim_vars: Sequence[str], shape: Sequence[DimSize]) ->\
-    Callable[..., TfVal]:
-  """Prepares a shape evaluator.
-
-  Returns a JAX function that given the values for the dimension variables
-  returns the values for the dimensions of `shape`.
-  """
-  def eval_shape(*dim_values: Any) -> Sequence[Any]:
-    shape_env_jax = dict(zip(dim_vars, dim_values))
-
-    def eval_dim(d: DimSize):
-      return d.evaluate(shape_env_jax)  # type: ignore[union-attr]
-
-    return tuple(eval_dim(d) if type(d) is _DimExpr else np.array(d, dtype=dim_as_value_dtype())  # type: ignore
-                 for d in shape)
-  return eval_shape
 
 def arg_aval(
     arg_shape: Sequence[Optional[int]],
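Note: dimension_size_p above now carries its own eager impl, abstract eval, and MLIR lowering rule. For readers unfamiliar with that wiring, here is a toy primitive set up the same way (a sketch using the same jax._src.core API; toy_dimension_size is made up and is not part of the diff):

    import numpy as np
    from jax._src import core

    toy_dim_size_p = core.Primitive("toy_dimension_size")
    # Eager path: bind() on concrete arrays calls the impl directly.
    toy_dim_size_p.def_impl(
        lambda arg, *, dimension: np.int32(arg.shape[dimension]))
    # Tracing path: the result is a weakly-typed int32 scalar.
    toy_dim_size_p.def_abstract_eval(
        lambda aval, **_: core.ShapedArray((), np.dtype(np.int32), weak_type=True))

    print(toy_dim_size_p.bind(np.ones((3, 5)), dimension=1))  # 5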
@@ -1046,53 +1011,55 @@ def arg_aval(
   aval_shape = _parse_spec(polymorphic_shape, arg_shape)
   return core.ShapedArray(aval_shape, arg_jax_dtype)
 
 
-def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Set[str]:
+def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
   dim_vars: Set[str] = set()
   for a in args_avals:
     for d in a.shape:
       if is_poly_dim(d):
         dim_vars = dim_vars.union(d.get_vars())
-  return dim_vars
+  return sorted(tuple(dim_vars))
 
-def prepare_dim_var_env(
+@dataclasses.dataclass
+class DimEquation:
+  # Represents arg.shape[dim_idx] == dim_expr
+  arg_idx: int
+  arg: jax.Array
+  dim_idx: int
+  dim_expr: _DimExpr
+
+  def __str__(self):
+    return f"{self.dim_expr} == args[{self.arg_idx}].shape[{self.dim_idx}]"
+
+
+def compute_dim_values(
     args_avals: Sequence[core.AbstractValue],
-    get_dimension_size: Callable[[Sequence[Any], int, int], DimExprValue]) -> \
-    Tuple[Sequence[str],
-          Callable[..., Sequence[DimExprValue]]]:
-  """Get the dimension variables and the function to compute them.
+    dim_vars: Sequence[str],
+    *args: jax.Array) -> Sequence[jax.Array]:
+  """Compute values of dimension variables from the actual arguments.
 
   Args:
-    args_aval: the abstract values of the array arguments
-    get_dimension_size: a function that given the array arguments, the argument
-      index, and the dimension index, computes the value of args[arg_idx].shape[dim_idx].
+    args_avals: the abstract values of the `args`, with shapes that may
+      include dimension variables.
+    dim_vars: the dimension variables
+    args: the actual arguments
 
-  Returns: a tuple of dimension variables that appear in `args_avals` along
-    with a function that given the actual arguments of the top-level function
-    returns a tuple of dimension variable values, in the same order as the
-    dimension variables returned in the first component.
+  Returns: the values of `dim_vars`.
   """
-  # TODO(necula): clean up the notion of shape environments, and solving for dimension
-  # variables.
-  dim_vars: Set[str] = all_dim_vars(args_avals)
-  dim_vars_sorted = sorted(tuple(dim_vars))
-  def get_dim_var_values(*args: Any) -> Sequence[Any]:
-    dim_equations: List[DimEquation] = []
-    for arg_idx, a in enumerate(args_avals):
-      for dim_idx, d in enumerate(a.shape):
-        if is_poly_dim(d):
-          dim_equations.append(
-              DimEquation(arg_idx=arg_idx, dim_idx=dim_idx, dim_expr=d))
-    dim_env = _solve_dim_equations(dim_equations,
-                                   functools.partial(get_dimension_size, args))
-    return tuple(dim_env[dv] for dv in dim_vars_sorted)
-  return dim_vars_sorted, get_dim_var_values
+  dim_equations: List[DimEquation] = []
+  for arg_idx, a in enumerate(args_avals):
+    for dim_idx, d in enumerate(a.shape):
+      if is_poly_dim(d):
+        dim_equations.append(
+            DimEquation(arg_idx=arg_idx, arg=args[arg_idx],
+                        dim_idx=dim_idx, dim_expr=d))
+  dim_env = _solve_dim_equations(dim_equations)
+  dim_values = tuple(dim_env[dv] for dv in dim_vars)
+  return dim_values
 
-def _solve_dim_equations(eqns: List[DimEquation],
-                         get_dimension_size: Callable[[int, int], DimExprValue]) -> ShapeEnv:
-
+def _solve_dim_equations(eqns: List[DimEquation]) -> DimVarEnv:
   # Returns a shape environment if it can solve all dimension variables.
   # Raises an exception if it cannot.
-  shapeenv: ShapeEnv = {}
+  shapeenv: DimVarEnv = {}
 
   def _shapeenv_to_str() -> str:
     if shapeenv:
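Note: after this change the solver works on jax.Array dimension values bound directly from the arguments (via dimension_size_p), rather than through a get_dimension_size callback. The per-equation arithmetic is unchanged; in plain integers it amounts to the following (my sketch, not the jax code):

    # Solve factor * var + constant == actual_dim for var, checking that the
    # division is even, as _solve_dim_equations does per monomial.
    def solve_linear(factor: int, constant: int, actual_dim: int) -> int:
        var_value, remainder = divmod(actual_dim - constant, factor)
        if remainder != 0:
            raise ValueError(
                f"Cannot solve {factor}*var + {constant} == {actual_dim}")
        return var_value

    assert solve_linear(2, 1, 9) == 4  # e.g. aval shape (2*b + 1,), actual dim 9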
@@ -1107,14 +1074,14 @@ def _solve_dim_equations(eqns: List[DimEquation],
     # Otherwise, add the variable to shapeenv and return True.
 
     var, factor_var = None, None
-    dim_value = get_dimension_size(eqn.arg_idx, eqn.dim_idx)
+    dim_value = dimension_size_p.bind(eqn.arg, dimension=eqn.dim_idx)
     # The invariant is: var * factor_var + rest_eqn_dim_expr = dim_value
 
     for mon, factor in eqn.dim_expr.monomials():
       # Perhaps we can already evaluate this monomial (all vars solved)
       try:
         mon_value = mon.evaluate(shapeenv)
-        dim_value = dim_value + -1 * _evaluate_multiply(mon_value, dim_constant(factor))
+        dim_value = dim_value + -1 * _evaluate_multiply(mon_value, core.dim_constant(factor))
         continue
       except KeyError:
         # There are some indeterminate variables. We handle only the case of
@@ -1128,9 +1095,9 @@
 
     if var is not None:
       if factor_var == 1:
-        var_value, var_remainder = dim_value, dim_constant(0)
+        var_value, var_remainder = dim_value, core.dim_constant(0)
       else:
-        var_value, var_remainder = divmod(dim_value, dim_constant(factor_var))  # type: ignore
+        var_value, var_remainder = divmod(dim_value, core.dim_constant(factor_var))  # type: ignore
 
       # Check that the division is even. Works only in eager mode.
       var_remainder_int = _is_known_constant(var_remainder)
@@ -1148,7 +1115,7 @@
                  f"{eqn}.{_shapeenv_to_str()}")
         raise ValueError(msg)
 
-      shapeenv[var] = var_value
+      shapeenv[var] = var_value.astype(core.dim_value_dtype())
       return True
     else:
       # All variables are resolved for this equation
|
@ -797,11 +797,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
arg_dtypes = (_f32,) * len(arg_shapes)
|
||||
def f_tf(*args_tf):
|
||||
avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
|
||||
def get_dimension_size(args, arg_idx: int, dim_idx: int) -> shape_poly.DimExprValue:
|
||||
return shape_poly.dimension_size_p.bind(args[arg_idx], dimension=dim_idx)
|
||||
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals, get_dimension_size)
|
||||
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(get_dim_values_jax,
|
||||
args_tf, avals, "")
|
||||
dim_vars = shape_poly.all_dim_vars(avals)
|
||||
dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(
|
||||
partial(shape_poly.compute_dim_values, avals, dim_vars),
|
||||
args_tf, avals, "")
|
||||
if expected_avals is not None:
|
||||
self.assertEqual(expected_avals, avals)
|
||||
return dict(zip(dim_vars, dim_values))
|
||||
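Note: the test relies on all_dim_vars now returning a sorted sequence rather than a set, because the final dict(zip(dim_vars, dim_values)) pairs names with values positionally, so producers and consumers must agree on one deterministic order. A quick illustration:

    dim_vars = sorted({"b", "a"})   # always ['a', 'b'], regardless of set order
    dim_values = (2, 3)             # computed in the same sorted order
    assert dict(zip(dim_vars, dim_values)) == {"a": 2, "b": 3}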