George Necula b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:

```
def dim_as_value(d):
   jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00

632 lines
22 KiB
Python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape polymorphism support for jax2tf.
For usage instructions, read the jax2tf.convert docstring, and the
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
"""
import collections
import itertools
import functools
import operator as op
import re
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
import jax
from jax._src.numpy import lax_numpy
import opt_einsum
from jax import config
from jax import core
from . import jax2tf as jax2tf_internal
import numpy as np
import tensorflow as tf # type: ignore[import]
DimSize = core.DimSize
Shape = core.Shape
class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
"""Raised when we cannot conclusively compute with symbolic dimensions."""
_help_msg = """
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
"""
def __init__(self, message: str):
error_msg = f"{message}\n{InconclusiveDimensionOperation._help_msg}"
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore
class _DimMon(dict):
"""Represents a multivariate monomial, such as n^3 * m.
The representation is a dictionary mapping var:exponent.
The `var` are strings and the exponents are >= 1.
The shape variables are assumed to range over integers >= 1.
"""
def __hash__(self):
return hash(frozenset(self.items()))
def __str__(self):
return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key)
for key, exponent in sorted(self.items()))
@classmethod
def from_var(cls, v: str) -> '_DimMon':
return _DimMon({v: 1})
def to_var(self) -> Optional[str]:
"""Extract the variable name "x", from a monomial "x".
Return None, if the monomial is not a single variable."""
items = self.items()
if len(items) != 1:
return None
(v, vexp), = items
if vexp != 1:
return None
return v
def get_vars(self) -> Set[str]:
return set(self.keys())
@property
def degree(self):
return sum(self.values())
def __lt__(self, other: '_DimMon'):
"""
Comparison to another monomial in graded reverse lexicographic order.
Used for sorting.
"""
self_key = -self.degree, tuple(sorted(self))
other_key = -other.degree, tuple(sorted(other))
return self_key > other_key
def mul(self, other: '_DimMon') -> '_DimMon':
"""
Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m.
"""
return _DimMon(collections.Counter(self) + collections.Counter(other))
def divide(self, divisor: '_DimMon') -> '_DimMon':
"""
Divides by another monomial. Raises a InconclusiveDimensionOperation
if the result is not a monomial.
For example, (n^3 * m) // n == n^2*m, but n // m fails.
"""
d = collections.Counter(self)
for key, exponent in divisor.items():
diff = self.get(key, 0) - exponent
if diff < 0:
raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
elif diff == 0: del d[key]
elif diff > 0: d[key] = diff
return _DimMon(d)
class _DimPolynomial(dict):
"""Polynomial with integer coefficients for polymorphic shapes.
The shape variables are assumed to range over integers >= 1.
We overload integer operations, but we do that soundly, raising
:class:`InconclusiveDimensionOperation` when the result is not
representable as a polynomial.
The representation of a polynomial is as a dictionary mapping _DimMonomial to
integer coefficients. The special monomial `_DimMonomial()` is mapped to the
free integer coefficient of the polynomial. The constant result of arithmetic
operations is represented as a Python constant.
"""
def __init__(self, coeffs: Dict[_DimMon, int]):
# Makes sure Polynomials are always in canonical form
coeffs = {mon: op.index(coeff)
for mon, coeff in coeffs.items() if coeff != 0}
coeffs = coeffs or {_DimMon(): 0}
super().__init__(coeffs)
@classmethod
def from_coeffs(cls, coeffs: Dict[_DimMon, int]) -> DimSize:
"""Constructs _DimPolynomial or an int."""
has_non_zero_degree = False
zero_degree_const = 0
new_coeffs = {}
for mon, count in coeffs.items():
if count != 0:
if mon.degree == 0:
zero_degree_const = count
else:
has_non_zero_degree = True
new_coeffs[mon] = count
if not has_non_zero_degree:
return int(zero_degree_const)
return _DimPolynomial(new_coeffs)
@classmethod
def from_var(cls, v: str) -> '_DimPolynomial':
return _DimPolynomial({_DimMon.from_var(v): 1})
def to_var(self) -> Optional[str]:
"""Extract the variable name "x", from a polynomial "x" """
items = self.items()
if len(items) != 1:
return None
(mon, mon_count), = items
if mon_count != 1:
return None
return mon.to_var()
def get_vars(self) -> Set[str]:
"""The variables that appear in a polynomial."""
acc = set()
for mon, _ in self.items():
acc.update(mon.get_vars())
return acc
def __hash__(self):
return hash(tuple(sorted(self.items())))
def __str__(self):
def _one_monomial(mon, c):
if mon.degree == 0:
return str(c)
if c == 1:
return str(mon)
return f"{c}*{mon}"
return " + ".join(_one_monomial(mon, c)
for mon, c in sorted(self.items(), reverse=True))
def __repr__(self):
return str(self)
# We overload , -, *, because they are fully defined for _DimPolynomial.
def __add__(self, other: DimSize) -> DimSize:
coeffs = self.copy()
for mon, coeff in _ensure_poly(other).items():
coeffs[mon] = coeffs.get(mon, 0) + coeff
return _DimPolynomial.from_coeffs(coeffs)
def __sub__(self, other: DimSize) -> DimSize:
return self + -other
def __neg__(self) -> '_DimPolynomial':
return _DimPolynomial({mon: -coeff for mon, coeff in self.items()})
def __mul__(self, other: DimSize) -> DimSize:
other = _ensure_poly(other)
coeffs: Dict[_DimMon, int] = {}
for (mon1, coeff1), (mon2, coeff2) in itertools.product(self.items(), other.items()):
mon = mon1.mul(mon2)
coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
return _DimPolynomial.from_coeffs(coeffs)
def __pow__(self, power, modulo=None):
assert modulo is None
try:
power = int(power)
except:
raise InconclusiveDimensionOperation(f"Dimension polynomial cannot be raised to non-integer power '{self}' ^ '{power}'")
return functools.reduce(op.mul, [self] * power)
def __rmul__(self, other: DimSize) -> DimSize:
return self * other # multiplication commutes
def __radd__(self, other: DimSize) -> DimSize:
return self + other # addition commutes
def __rsub__(self, other: DimSize) -> DimSize:
return _ensure_poly(other) - self
def eq(self, other: DimSize) -> bool:
lb, ub = _ensure_poly(self - other).bounds()
if lb == ub == 0:
return True
if lb is not None and lb > 0:
return False
if ub is not None and ub < 0:
return False
raise InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' == '{other}' is inconclusive")
# We must overload __eq__ and __ne__, or else we get unsound defaults.
__eq__ = eq
def __ne__(self, other: DimSize) -> bool:
return not self.eq(other)
def ge(self, other: DimSize) -> bool:
lb, ub = _ensure_poly(self - other).bounds()
if lb is not None and lb >= 0:
return True
if ub is not None and ub < 0:
return False
raise InconclusiveDimensionOperation(f"Dimension polynomial comparison '{self}' >= '{other}' is inconclusive")
__ge__ = ge
def __le__(self, other: DimSize):
return _ensure_poly(other).__ge__(self)
def __gt__(self, other: DimSize):
return not _ensure_poly(other).__ge__(self)
def __lt__(self, other: DimSize):
return not self.__ge__(other)
def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
"""
Floor division with remainder (divmod) generalized to polynomials.
If the `divisor` is not a constant, the remainder must be 0.
If the `divisor` is a constant, the remainder may be non 0, for consistency
with integer divmod.
:return: Quotient resulting from polynomial division and integer remainder.
"""
divisor = _ensure_poly(divisor)
dmon, dcount = divisor.leading_term
dividend, quotient = self, 0
err_msg = f"Dimension polynomial '{self}' is not a multiple of '{divisor}'"
# invariant: self = dividend + divisor * quotient
# the leading term of dividend decreases through the loop.
while is_poly_dim(dividend) and not dividend.is_constant:
mon, count = dividend.leading_term
try:
qmon = mon.divide(dmon)
except InconclusiveDimensionOperation:
raise InconclusiveDimensionOperation(err_msg)
qcount, rcount = divmod(count, dcount)
if rcount != 0:
raise InconclusiveDimensionOperation(err_msg)
q = _DimPolynomial.from_coeffs({qmon: qcount})
quotient += q
dividend -= q * divisor # type: ignore[assignment]
dividend = int(dividend) # type: ignore[assignment]
if divisor.is_constant:
q, r = divmod(dividend, int(divisor)) # type: ignore
quotient += q
remainder = r
else:
if dividend != 0:
raise InconclusiveDimensionOperation(err_msg)
remainder = 0
if config.jax_enable_checks:
assert self == divisor * quotient + remainder
return quotient, remainder
def __floordiv__(self, divisor: DimSize) -> DimSize:
return self.divmod(divisor)[0]
def __rfloordiv__(self, other):
return _ensure_poly(other).__floordiv__(self)
def __mod__(self, divisor: DimSize) -> int:
return self.divmod(divisor)[1]
__divmod__ = divmod
def __rdivmod__(self, other: DimSize) -> Tuple[DimSize, int]:
return _ensure_poly(other).divmod(self)
def __int__(self):
if self.is_constant:
return op.index(next(iter(self.values())))
else:
raise InconclusiveDimensionOperation(f"Dimension polynomial '{self}' is not constant")
def bounds(self) -> Tuple[Optional[int], Optional[int]]:
"""Returns the lower and upper bounds, if defined."""
lb = ub = self.get(_DimMon(), 0) # The free coefficient
for mon, coeff in self.items():
if mon.degree == 0: continue
if coeff > 0:
ub = None
lb = None if lb is None else lb + coeff
else:
lb = None
ub = None if ub is None else ub + coeff
return lb, ub
def evaluate(self, env: Dict[str, Any]) -> Any:
prod = lambda xs: functools.reduce(op.mul, xs) if xs else 1
def mul(coeff, mon):
try:
coeff = int(coeff)
except:
return coeff * mon
else:
return 0 if coeff == 0 else mon if coeff == 1 else coeff * mon
terms = [mul(coeff, prod([pow(env[id], deg) for id, deg in mon.items()]))
for mon, coeff in self.items()]
return sum(terms) if len(terms) > 1 else terms[0]
@property
def is_constant(self):
return len(self) == 1 and next(iter(self)).degree == 0
@property
def leading_term(self) -> Tuple[_DimMon, int]:
"""Returns the highest degree term that comes first lexicographically."""
return max(self.items())
def _ensure_poly(p: DimSize) -> _DimPolynomial:
if isinstance(p, _DimPolynomial): return p
return _DimPolynomial({_DimMon(): p})
def is_poly_dim(p: DimSize) -> bool:
return isinstance(p, _DimPolynomial)
class DimensionHandlerPoly(core.DimensionHandler):
"""See core.DimensionHandler.
Most methods are inherited.
"""
def is_constant(self, d: DimSize) -> bool:
assert isinstance(d, _DimPolynomial)
return False
def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
try:
return _ensure_poly(d1) == d2
except InconclusiveDimensionOperation:
return False
def greater_equal(self, d1: DimSize, d2: DimSize):
return _ensure_poly(d1) >= d2
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
sz1 = np.prod(s1)
sz2 = np.prod(s2)
if core.symbolic_equal_dim(sz1, sz2): # Takes care also of sz1 == sz2 == 0
return 1
err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
try:
q, r = _ensure_poly(sz1).divmod(sz2)
except InconclusiveDimensionOperation:
raise InconclusiveDimensionOperation(err_msg)
if r != 0:
raise InconclusiveDimensionOperation(err_msg)
return q # type: ignore[return-value]
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""Implements `(d - window_size) // window_stride + 1`"""
try:
q, r = _ensure_poly(d - window_size).divmod(window_stride)
return q + 1
except InconclusiveDimensionOperation as e:
raise InconclusiveDimensionOperation(
f"Cannot compute stride for dimension '{d}', "
f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.")
return d
def as_value(self, d: DimSize):
"""Turns a dimension size into a Jax value that we can compute with."""
return _dim_as_value(d)
core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()
def _einsum_contract_path(*operands, **kwargs):
"""Like opt_einsum.contract_path, with support for DimPolynomial shapes.
We use opt_einsum.contract_path to compute the schedule, using a fixed
constant for all dimension variables. This is safe because we throw an
error if there are more than 1 contractions. Essentially, we just use
opt_einsum.contract_path to parse the specification.
"""
# Replace the polymorphic shapes with some concrete shapes for calling
# into opt_einsum.contract_path, because the latter wants to compute the
# sizes of operands and intermediate results.
fake_ops = []
for operand in operands:
# We replace only array operands
if not hasattr(operand, "dtype"):
fake_ops.append(operand)
else:
shape = np.shape(operand)
def fake_dim(d):
if core.is_constant_dim(d):
return d
else:
if not isinstance(d, _DimPolynomial):
raise TypeError(f"Encountered unexpected shape dimension {d}")
# It is Ok to replace all polynomials with the same value. We may miss
# here some errors due to non-equal dimensions, but we catch them
# later.
return 8
fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)),
operand.dtype))
contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops,
**kwargs)
if len(contractions) > 1:
msg = ("Shape polymorphism is not yet supported for einsum with more than "
f"one contraction {contractions}")
raise ValueError(msg)
contract_operands = []
for operand in contract_fake_ops:
idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op)
assert len(idx) == 1
contract_operands.append(operands[idx[0]])
return contract_operands, contractions
lax_numpy._polymorphic_einsum_contract_path_handlers[_DimPolynomial] = _einsum_contract_path
# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimPoly. The value of the primitive is the value of the dimension.
# This primitive is used only in the context of jax2tf, so it does not need
# XLA translation rules.
dim_as_value_p = core.Primitive("dim_as_value")
def _dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
return core.ShapedArray((), np.int32)
dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)
def _dim_as_value(dim: DimSize):
return dim_as_value_p.bind(dim=dim)
def _dim_as_value_jax2tf(dim: DimSize):
dim_tf, = jax2tf_internal._eval_shape((dim,))
assert dim_tf.dtype == tf.int32
return dim_tf
def _register_conversion_rules():
jax2tf_internal.tf_impl[dim_as_value_p] = _dim_as_value_jax2tf
class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
See docstring of :func:`jax2tf.convert`.
"""
def __new__(cls, *dim_specs):
for i, ds in enumerate(dim_specs):
if not isinstance(ds, (int, str)) and ds != ...:
msg = (f"Invalid PolyShape element: {repr(ds)}; must be a string "
"representing a dimension variable, or an integer, or ...")
raise ValueError(msg)
return tuple.__new__(PolyShape, dim_specs)
def parse_spec(spec: Optional[Union[str, PolyShape]],
arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
"""Parse the shape polymorphic specification for one array argument.
Args:
spec: a shape polymorphic specification, either a string, or a PolyShape.
arg_shape: an actual shape, possibly containing unknown dimensions (None).
The placeholders `_` in the specification are replaced with the values from
the actual shape, which must be known.
See the README.md for usage.
"""
if spec is None:
spec_tuple = (...,) # type: Tuple[Any,...]
elif isinstance(spec, PolyShape):
spec_tuple = spec
elif isinstance(spec, str):
spec_ = spec.strip()
if spec_[0] == "(":
if spec_[-1] != ")":
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
spec_ = spec_[1:-1]
spec_ = spec_.strip()
spec_ = spec_.rstrip(",")
if not spec_:
spec_tuple = ()
else:
spec_tuple = spec_.split(",") # type: ignore
else:
raise ValueError(f"PolyShape {repr(spec)} must be either None, a string, or PolyShape.")
# Process ...
spec_tuple = tuple(map(lambda s: ... if isinstance(s, str) and s.strip() == "..." else s,
spec_tuple))
ds_ellipses = tuple(ds for ds in spec_tuple if ds == ...)
if ds_ellipses:
if len(ds_ellipses) > 1 or spec_tuple[-1] != ...:
raise ValueError(f"PolyShape {repr(spec)} can contain Ellipsis only at the end.")
spec_tuple = spec_tuple[0:-1]
if len(arg_shape) >= len(spec_tuple):
spec_tuple = spec_tuple + ("_",) * (len(arg_shape) - len(spec_tuple))
if len(arg_shape) != len(spec_tuple):
raise ValueError(f"PolyShape {repr(spec)} of rank {len(spec_tuple)} must match the rank {len(arg_shape)} of argument shape {arg_shape}.")
# The actual parsing.
# We actually parse not just dimension variables, but polynomials.
# This is not a supported feature of the API, but is needed when parsing the
# polymorphic_shapes of a gradient function, when the primal function has polynomial
# output shapes.
def _parse_dim(dim_spec: Union[str, int]) -> DimSize:
if isinstance(dim_spec, int):
return dim_spec #
dim_spec = dim_spec.strip()
if not dim_spec:
raise ValueError(f"PolyShape {repr(spec)} has invalid syntax (empty dimension {dim_spec}')")
# Terms are separated by "+"
terms = dim_spec.split("+")
if not terms:
raise ValueError(f"PolyShape {repr(spec)} has invalid syntax (empty dimension {dim_spec}')")
def _parse_term(term_spec: str) -> DimSize:
term_spec = term_spec.strip()
# Factors are separated by "*"
factors = term_spec.split("*")
if not factors:
raise ValueError(f"PolyShape {repr(spec)} has invalid syntax (unexpected term '{term_spec}')")
def _parse_factor(factor_spec: str) -> DimSize:
factor_spec = factor_spec.strip()
if re.match(r"^-?\d+$", factor_spec):
return int(factor_spec)
m = re.match(r"^([a-zA-Z]\w*)(\^(\d+))?$", factor_spec)
if not m:
raise ValueError(f"PolyShape {repr(spec)} has invalid syntax (unexpected term '{factor_spec}')")
var = _DimPolynomial.from_var(m.group(1))
if m.group(3) is None:
return var
return var ** int(m.group(3))
return functools.reduce(op.mul, map(_parse_factor, factors))
return functools.reduce(op.add, map(_parse_term, terms))
def _process_dim(i: int, dim_spec: Union[str, int]):
if isinstance(dim_spec, str):
dim_spec = dim_spec.strip()
dim_size = arg_shape[i]
if isinstance(dim_size, tf.compat.v1.Dimension):
dim_size = dim_size.value
if dim_size is None:
if dim_spec == "_":
msg = (f"PolyShape {repr(spec)} in axis {i} must contain a shape variable "
f"for unknown dimension in argument shape {arg_shape}")
raise ValueError(msg)
dim_poly = _parse_dim(dim_spec)
if not is_poly_dim(dim_poly):
msg = (f"PolyShape {repr(spec)} in axis {i} must contain a shape variable "
f"for unknown dimension in argument shape {arg_shape}")
raise ValueError(msg)
return dim_poly
else: # dim_size is known
dim_size = int(dim_size)
if dim_spec == "_":
return dim_size
dim_poly = _parse_dim(dim_spec)
if not is_poly_dim(dim_poly):
if dim_poly != dim_size:
msg = (f"PolyShape {repr(spec)} in axis {i} must contain a constant or '_' "
f"for known dimension in argument shape {arg_shape}")
raise ValueError(msg)
return dim_size
return dim_poly
dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
return dims
def eval_shape(shape: Sequence[DimSize], shape_env: Dict[str, Any]) -> Sequence[Any]:
return tuple(d.evaluate(shape_env) if type(d) is _DimPolynomial else d # type: ignore
for d in shape)