[jax2tf] Fix conversion of gradients for shape polymorphic functions.

This fixes the case when the primal shape-polymorphic function has
output shapes that are polynomials of the input dimension variables
(not just the dimension variables themselves).
George Necula, 2021-06-23 10:52:03 +02:00
commit 7e335e0e2e (parent b85e3a0107)
4 changed files with 134 additions and 46 deletions
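In short, gradients of converted functions whose outputs have polynomial dimensions (e.g. a reshape from [b, 3] to [3*b]) now work. A condensed illustration, mirroring the new test_grad_not_var_output test added below:

    import numpy as np
    import tensorflow as tf
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    def f_jax(x):  # x: [b, 3]
      return jnp.reshape(x, (-1,))  # result: [3*b], a polynomial of b

    f_tf = jax2tf.convert(f_jax, with_gradient=True,
                          polymorphic_shapes=["b, ..."])
    xv = tf.Variable(np.arange(12, dtype=np.float32).reshape((4, 3)))
    with tf.GradientTape() as tape:
      res = f_tf(xv)
    grad = tape.gradient(res, xv)  # same shape as xv: (4, 3), all ones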


@@ -1080,7 +1080,7 @@ class ShapedArray(UnshapedArray):
                        self.weak_type, self.named_shape)

   def join(self, other):
-    if self.shape == other.shape and self.dtype == other.dtype:
+    if symbolic_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
       weak_type = self.weak_type and other.weak_type
       named_shape = join_named_shapes(self.named_shape, other.named_shape)
       return self.update(weak_type=weak_type, named_shape=named_shape)
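symbolic_equal_shape compares dimensions symbolically, so avals whose dimensions are equal polynomials still join. A minimal sketch, assuming the shape_poly dimension handler is registered (as it is once jax2tf's shape_poly module is imported):

    from jax import core
    from jax.experimental.jax2tf import shape_poly

    a, b = shape_poly.parse_spec("a, b", (2, 3))
    # 3*a and a + 2*a are the same polynomial, so the shapes compare equal:
    assert core.symbolic_equal_shape((3 * a, b), (a + 2 * a, b))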


@@ -335,7 +335,7 @@ def convert(fun: Callable,
           in_tree.children()[0], polymorphic_shapes_flat)
       out_cts_polymorphic_shapes = tree_util.tree_unflatten(
           out_tree,
-          tuple(str(out_aval.shape)
+          tuple(str(out_aval.shape)  # Note: may be polynomials, not just DimVar
                 for out_aval in _out_cts_avals))  # type: ignore
       vjp_polymorphic_shapes = [
           args_polymorphic_shapes, out_cts_polymorphic_shapes
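The cotangent specs are produced with str() on the output avals, so a polynomial dimension such as 3*b must print in a form that parse_spec can read back. The round trip in miniature (this is exactly what the new test_parse_poly_spec_poly below checks):

    from jax.experimental.jax2tf import shape_poly

    a, = shape_poly.parse_spec("a", (4,))
    dim_poly = 3 * a  # e.g. the output dimension of a reshape to (-1,)
    assert shape_poly.parse_spec(str(dim_poly), (12,)) == (dim_poly,)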
@@ -584,18 +584,20 @@ def _args_to_avals_and_env(
     aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)
     for i, d in enumerate(aval_shape):
-      if type(d) is int:
+      if not shape_poly.is_poly_dim(d):
+        assert isinstance(d, int)
         assert d == arg_shape[i]
-      elif shape_poly.is_poly_dim(d) and d not in shapeenv:
-        # Even if the shape of `arg` is known, we still use `tf.shape` for
-        # safety, because the promise is that we will convert the function
-        # to work for any value of the dimension.
-        v = d.to_var()  # type: ignore
-        assert v is not None
-        shapeenv[v] = tf.shape(arg)[i]  # type: ignore[index]
-      else:
-        # TODO: add an assertion tf.shape(arg)[i] == env[d]
-        pass
+      else:
+        d_var = d.to_var()  # type: ignore
+        if d_var is not None and d_var not in shapeenv:
+          # Even if the shape of `arg` is known, we still use `tf.shape` for
+          # safety, because the promise is that we will convert the function
+          # to work for any value of the dimension.
+          shapeenv[d_var] = tf.shape(arg)[i]  # type: ignore[index]
+        else:
+          # TODO: add an assertion tf.shape(arg)[i] == env[d]
+          pass
     return core.ShapedArray(aval_shape, arg_jax_dtype)
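A sketch of what the dimension environment ends up holding; the names here are illustrative, not the actual jax2tf internals. For an argument of shape (4, 3) converted with polymorphic_shapes="b, 3", the variable b is bound to a tf.shape tensor even though its value is statically known:

    import tensorflow as tf

    arg = tf.zeros((4, 3))
    shapeenv = {}
    # Per the comment in the diff above, tf.shape is used even for known
    # dimensions, so the converted function works for any value of b.
    shapeenv["b"] = tf.shape(arg)[0]  # a scalar tf.Tensor, not the int 4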


@@ -21,6 +21,7 @@ import collections
 import itertools
 import functools
 import operator as op
+import re
 from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
@@ -64,7 +65,7 @@ class _DimMon(dict):
     return hash(frozenset(self.items()))

   def __str__(self):
-    return ' '.join(f'{key}^{exponent}' if exponent != 1 else str(key)
+    return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key)
                     for key, exponent in sorted(self.items()))

   @classmethod
@@ -184,8 +185,14 @@ class _DimPolynomial(dict):
     return hash(tuple(sorted(self.items())))

   def __str__(self):
-    return ' + '.join(f'{c} {mon}' if c != 1 or mon.degree == 0 else str(mon)
-                      for mon, c in sorted(self.items(), reverse=True)).strip()
+    def _one_monomial(mon, c):
+      if mon.degree == 0:
+        return str(c)
+      if c == 1:
+        return str(mon)
+      return f"{c}*{mon}"
+    return " + ".join(_one_monomial(mon, c)
+                      for mon, c in sorted(self.items(), reverse=True))

   def __repr__(self):
     return str(self)
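With the two rewritten printers, polynomials render in the same "*"/"^" syntax that parse_spec accepts; the exact string below is asserted in the updated comparison test:

    from jax.experimental.jax2tf import shape_poly

    a, b = shape_poly.parse_spec("a, b", (2, 3))
    print(str(3 * a * b * a - 2))  # prints: 3*a^2*b + -2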
@@ -211,6 +218,12 @@ class _DimPolynomial(dict):
         coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
     return _DimPolynomial.from_coeffs(coeffs)

+  def __pow__(self, power, modulo=None):
+    assert modulo is None
+    if not isinstance(power, int):
+      raise InconclusiveDimensionOperation(f"Dimension polynomial cannot be raised to non-integer power '{self}' ^ '{power}'")
+    return functools.reduce(op.mul, [self] * power)
+
   def __rmul__(self, other: DimSize) -> DimSize:
     return self * other  # multiplication commutes
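__pow__ exists to support the parser's var ** int(m.group(3)) below; an integer power is just repeated multiplication:

    from jax.experimental.jax2tf import shape_poly

    a, b = shape_poly.parse_spec("a, b", (2, 3))
    assert a ** 2 == a * a
    assert str(a ** 2) == "a^2"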
@@ -476,32 +489,25 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
   if spec is None:
     spec_tuple = (...,)  # type: Tuple[Any,...]
   elif isinstance(spec, PolyShape):
-    spec_tuple = tuple(spec)
+    spec_tuple = spec
   elif isinstance(spec, str):
-    spec_ = spec.replace(" ", "")
+    spec_ = spec.strip()
     if spec_[0] == "(":
       if spec_[-1] != ")":
         raise ValueError(f"PolyShape '{spec}' has invalid syntax")
       spec_ = spec_[1:-1]
+      spec_ = spec_.strip()
     spec_ = spec_.rstrip(",")
     if not spec_:
       spec_tuple = ()
     else:
-      specs = spec_.split(',')
-      def parse_dim(ds: str):
-        if ds == "...":
-          return ...
-        elif ds.isdigit():
-          return int(ds)
-        elif ds == "_" or ds.isalnum():
-          return ds
-        else:
-          raise ValueError(f"PolyShape '{spec}' has invalid syntax")
-      spec_tuple = tuple(map(parse_dim, specs))
+      spec_tuple = spec_.split(",")  # type: ignore
   else:
     raise ValueError(f"PolyShape '{spec}' must be either None, a string, or PolyShape.")

   # Process ...
+  spec_tuple = tuple(map(lambda s: ... if isinstance(s, str) and s.strip() == "..." else s,
+                         spec_tuple))
   ds_ellipses = tuple(ds for ds in spec_tuple if ds == ...)
   if ds_ellipses:
     if len(ds_ellipses) > 1 or spec_tuple[-1] != ...:
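The accepted spellings of a shape spec after this normalization (the last two are asserted by the existing tests shown below; the None case falls out of the Ellipsis expansion):

    from jax.experimental.jax2tf import shape_poly

    assert shape_poly.parse_spec(None, (2, 3)) == (2, 3)
    assert shape_poly.parse_spec("...", (2, 3)) == (2, 3)
    assert shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)) == (2, 3)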
@@ -513,29 +519,73 @@
   if len(arg_shape) != len(spec_tuple):
     raise ValueError(f"PolyShape '{spec}' must match the rank of arguments {arg_shape}.")

+  # The actual parsing.
+  # We actually parse not just dimension variables, but polynomials.
+  # This is not a supported feature of the API, but is needed when parsing the
+  # polymorphic_shapes of a gradient function, when the primal function has
+  # polynomial output shapes.
+  def _parse_dim(dim_spec: Union[str, int]) -> DimSize:
+    if isinstance(dim_spec, int):
+      return dim_spec
+    dim_spec = dim_spec.strip()
+    if not dim_spec:
+      raise ValueError(f"PolyShape '{spec}' has invalid syntax (empty dimension '{dim_spec}')")
+    # Terms are separated by "+"
+    terms = dim_spec.split("+")
+    if not terms:
+      raise ValueError(f"PolyShape '{spec}' has invalid syntax (empty dimension '{dim_spec}')")
+
+    def _parse_term(term_spec: str) -> DimSize:
+      term_spec = term_spec.strip()
+      # Factors are separated by "*"
+      factors = term_spec.split("*")
+      if not factors:
+        raise ValueError(f"PolyShape '{spec}' has invalid syntax (unexpected term '{term_spec}')")
+
+      def _parse_factor(factor_spec: str) -> DimSize:
+        factor_spec = factor_spec.strip()
+        if re.match(r"^-?\d+$", factor_spec):
+          return int(factor_spec)
+        m = re.match(r"^([a-zA-Z]\w*)(\^(\d+))?$", factor_spec)
+        if not m:
+          raise ValueError(f"PolyShape '{spec}' has invalid syntax (unexpected term '{factor_spec}')")
+        var = _DimPolynomial.from_var(m.group(1))
+        if m.group(3) is None:
+          return var
+        return var ** int(m.group(3))
+
+      return functools.reduce(op.mul, map(_parse_factor, factors))
+
+    return functools.reduce(op.add, map(_parse_term, terms))

   shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
-  def _process_dim(i: int, dim_spec):
-    if not isinstance(dim_spec, (str, int)):
-      raise ValueError(f"PolyShape '{spec}' in axis {i} must contain only integers, strings, or Ellipsis.")
+  def _process_dim(i: int, dim_spec: Union[str, int]):
+    if isinstance(dim_spec, str):
+      dim_spec = dim_spec.strip()
     dim_size = arg_shape[i]
     if dim_size is None:
-      if dim_spec == "_" or not isinstance(dim_spec, str):
+      if dim_spec == "_":
         msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
                f"for unknown dimension in argument shape {arg_shape}")
         raise ValueError(msg)
-      return _DimPolynomial.from_var(dim_spec)
+      dim_poly = _parse_dim(dim_spec)
+      if not isinstance(dim_poly, _DimPolynomial):
+        msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
+               f"for unknown dimension in argument shape {arg_shape}")
+        raise ValueError(msg)
+      return dim_poly
     else:  # dim_size is known
       if dim_spec == "_":
         return dim_size
-      if isinstance(dim_spec, int):
-        if dim_spec != dim_size:
+      dim_poly = _parse_dim(dim_spec)
+      if isinstance(dim_poly, int):
+        if dim_poly != dim_size:
           msg = (f"PolyShape '{spec}' in axis {i} must contain a constant or '_' "
                  f"for known dimension in argument shape {arg_shape}")
           raise ValueError(msg)
         return dim_size
-      # We have a dimension variable for a known dimension.
-      shape_var_map[dim_spec].add(dim_size)
-      return _DimPolynomial.from_var(dim_spec)
+      # We have a dimension polynomial for a known dimension.
+      dim_var = dim_poly.to_var()
+      if dim_var is not None:
+        shape_var_map[dim_spec].add(dim_size)  # type: ignore
+      return dim_poly

   dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
   for dim_var, dim_var_values in shape_var_map.items():
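The grammar _parse_dim accepts: terms joined by "+", factors joined by "*", optional integer exponents written "^". From the new parameterized test:

    from jax.experimental.jax2tf import shape_poly

    a, b = shape_poly.parse_spec("a, b", (2, 3))
    assert shape_poly.parse_spec("2*a*b", (2,)) == (2 * a * b,)
    assert shape_poly.parse_spec("-2 * a^2 * b + b^2", (2,)) == (-2 * a * a * b + b * b,)
    assert shape_poly.parse_spec("a + -1", (2,)) == (a - 1,)

Note that there is no "-" separator: negative coefficients are written "+ -1", which keeps the term splitter a plain split on "+" (the invalid-syntax test below includes "a - 2").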


@@ -59,9 +59,28 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
     self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
     self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))

+  a, b = shape_poly.parse_spec("a, b", (2, 3))
+
+  @parameterized.named_parameters(
+      dict(testcase_name=f"_dim_spec={dim_spec}",
+           dim_spec=dim_spec, dim_poly=dim_poly)
+      for dim_spec, dim_poly in [
+          ("2*a*b", 2 * a * b),
+          ("-2 * a^2 * b + b^2", -2 * a * a * b + b * b),
+          ("-2 * a^2 * b + -1 *b^2*a", -2 * a * a * b - a * b * b),
+          ("3 * a * b * a + -2", 3 * a * b * a - 2),
+          ("a + 1", a + 1),
+          ("a + -1", a - 1),
+      ])
+  def test_parse_poly_spec_poly(self, dim_spec="3 * a * b * a + -2", dim_poly=3 * a * b * a - 2):
+    # For internal usage only (the polymorphic_shapes of VJP) we need to
+    # parse polynomials.
+    self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
+    self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
+
   def test_dim_vars(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b, a1 = shape_poly.parse_spec("a, b, a", (2, 3, 2))
     self.assertEqual(True, a == a)
+    self.assertEqual(True, a == a1)
     self.assertEqual(False, a != a)
     with self.assertRaisesRegex(
         core.InconclusiveDimensionOperation,
@@ -145,7 +164,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
     self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
     self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
     with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
-                                re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")):
+                                re.escape("Dimension polynomial comparison '3*a^2*b + -2' == 'a^2*b' is inconclusive")):
       (3 * a * b * a - 2).eq(a * b * a)

   def test_poly_compare(self):
@@ -202,7 +221,6 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
     self.assertEqual(a * 2 // a, 2)
     self.assertIsInstance(a * 2 // a, int)

-  a, b = shape_poly.parse_spec("a, b", (2, 3))
   @parameterized.named_parameters(
       dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
            dividend=dividend, divisor=divisor, quotient=quotient,
@@ -362,10 +380,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     )

     # Some errors
-    with self.assertRaisesRegex(ValueError,
-                                re.escape("PolyShape ')(' has invalid syntax")):
-      check_avals(
-          args=[const((2, 3))], polymorphic_shapes=[")("], expected_avals=None)
+    for invalid_syntax in [")(", "2a", "a@", "a - 2"]:
+      with self.assertRaisesRegex(ValueError,
+                                  re.escape("PolyShape '" + invalid_syntax + "' has invalid syntax")):
+        check_avals(
+            args=[const((2,))], polymorphic_shapes=[invalid_syntax],
+            expected_avals=None)

     with self.assertRaisesRegex(
         ValueError,
@@ -592,6 +612,22 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     # The shape of the gradient should match the input
     self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))

+  def test_grad_not_var_output(self):
+    # The output of the function has polynomial shapes, not just
+    # dimension variables.
+    def f_jax(x):  # x: [b, 3]
+      return jnp.reshape(x, (-1,))  # res: [3*b]
+    x = np.arange(12, dtype=np.float32).reshape((4, 3))
+    xv = tf.Variable(x)
+    f_tf = jax2tf.convert(f_jax, with_gradient=True,
+                          polymorphic_shapes=["b, ..."])
+
+    with tf.GradientTape() as tape:
+      res_tf = f_tf(xv)
+    grad_tf = tape.gradient(res_tf, xv)
+    self.assertAllClose(np.ones(x.shape, dtype=np.float32), grad_tf.numpy())
+
   def test_cond(self):
     # Test the primitive under conditional
     def f(x, y):