mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Fix conversion of gradients for shape polymorphic functions.
This fixes the case when the primal shape-polymorphic function has output shapes that are polynomials of the input dimension variables (not just single dimension variables), e.g., reshaping an input of shape `(b, 3)` to `(3*b,)`.
This commit is contained in:
parent
b85e3a0107
commit
7e335e0e2e
@ -1080,7 +1080,7 @@ class ShapedArray(UnshapedArray):
|
||||
self.weak_type, self.named_shape)
|
||||
|
||||
def join(self, other):
|
||||
if self.shape == other.shape and self.dtype == other.dtype:
|
||||
if symbolic_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
|
||||
weak_type = self.weak_type and other.weak_type
|
||||
named_shape = join_named_shapes(self.named_shape, other.named_shape)
|
||||
return self.update(weak_type=weak_type, named_shape=named_shape)
|
||||
|
@ -335,7 +335,7 @@ def convert(fun: Callable,
|
||||
in_tree.children()[0], polymorphic_shapes_flat)
|
||||
out_cts_polymorphic_shapes = tree_util.tree_unflatten(
|
||||
out_tree,
|
||||
tuple(str(out_aval.shape)
|
||||
tuple(str(out_aval.shape) # Note: may be polynomials, not just DimVar
|
||||
for out_aval in _out_cts_avals)) # type: ignore
|
||||
vjp_polymorphic_shapes = [
|
||||
args_polymorphic_shapes, out_cts_polymorphic_shapes
|
||||
@ -584,18 +584,20 @@ def _args_to_avals_and_env(
|
||||
aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)
|
||||
|
||||
for i, d in enumerate(aval_shape):
|
||||
if type(d) is int:
|
||||
if not shape_poly.is_poly_dim(d):
|
||||
assert isinstance(d, int)
|
||||
assert d == arg_shape[i]
|
||||
elif shape_poly.is_poly_dim(d) and d not in shapeenv:
|
||||
# Even if the shape of `arg` is known, we still use `tf.shape` for
|
||||
# safety, because the promise is that we will convert the function
|
||||
# to work for any value of the dimension.
|
||||
v = d.to_var() # type: ignore
|
||||
assert v is not None
|
||||
shapeenv[v] = tf.shape(arg)[i] # type: ignore[index]
|
||||
else:
|
||||
# TODO: add an assertion tf.shape(arg)[i] == env[d]
|
||||
pass
|
||||
d_var = d.to_var() # type: ignore
|
||||
if d_var is not None and d_var not in shapeenv:
|
||||
# Even if the shape of `arg` is known, we still use `tf.shape` for
|
||||
# safety, because the promise is that we will convert the function
|
||||
# to work for any value of the dimension.
|
||||
shapeenv[d_var] = tf.shape(arg)[i] # type: ignore[index]
|
||||
else:
|
||||
# TODO: add an assertion tf.shape(arg)[i] == env[d]
|
||||
pass
|
||||
|
||||
|
||||
return core.ShapedArray(aval_shape, arg_jax_dtype)
|
||||
|
||||
|
@ -21,6 +21,7 @@ import collections
|
||||
import itertools
|
||||
import functools
|
||||
import operator as op
|
||||
import re
|
||||
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
|
||||
@ -64,7 +65,7 @@ class _DimMon(dict):
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
def __str__(self):
|
||||
return ' '.join(f'{key}^{exponent}' if exponent != 1 else str(key)
|
||||
return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key)
|
||||
for key, exponent in sorted(self.items()))
|
||||
|
||||
@classmethod
|
||||
@ -184,8 +185,14 @@ class _DimPolynomial(dict):
|
||||
return hash(tuple(sorted(self.items())))
|
||||
|
||||
def __str__(self):
|
||||
return ' + '.join(f'{c} {mon}' if c != 1 or mon.degree == 0 else str(mon)
|
||||
for mon, c in sorted(self.items(), reverse=True)).strip()
|
||||
def _one_monomial(mon, c):
|
||||
if mon.degree == 0:
|
||||
return str(c)
|
||||
if c == 1:
|
||||
return str(mon)
|
||||
return f"{c}*{mon}"
|
||||
return " + ".join(_one_monomial(mon, c)
|
||||
for mon, c in sorted(self.items(), reverse=True))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
@ -211,6 +218,12 @@ class _DimPolynomial(dict):
|
||||
coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
|
||||
return _DimPolynomial.from_coeffs(coeffs)
|
||||
|
||||
def __pow__(self, power, modulo=None):
  """Raise this dimension polynomial to a non-negative integer power.

  Args:
    power: the exponent; must be an `int` that is >= 0.
    modulo: must be None; three-argument `pow` is not supported.

  Returns:
    The int 1 when `power` is 0, otherwise the product of `power` copies
    of `self`.

  Raises:
    InconclusiveDimensionOperation: if `power` is not an integer, or is
      negative.
  """
  assert modulo is None
  if not isinstance(power, int):
    raise InconclusiveDimensionOperation(f"Dimension polynomial cannot be raised to non-integer power '{self}' ^ '{power}'")
  if power < 0:
    raise InconclusiveDimensionOperation(f"Dimension polynomial cannot be raised to negative power '{self}' ^ '{power}'")
  if power == 0:
    # `reduce` without an initial value raises TypeError on an empty list;
    # the empty product is 1.
    return 1
  return functools.reduce(op.mul, [self] * power)
|
||||
|
||||
def __rmul__(self, other: DimSize) -> DimSize:
|
||||
return self * other # multiplication commutes
|
||||
|
||||
@ -476,32 +489,25 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
|
||||
if spec is None:
|
||||
spec_tuple = (...,) # type: Tuple[Any,...]
|
||||
elif isinstance(spec, PolyShape):
|
||||
spec_tuple = tuple(spec)
|
||||
spec_tuple = spec
|
||||
elif isinstance(spec, str):
|
||||
spec_ = spec.replace(" ", "")
|
||||
spec_ = spec.strip()
|
||||
if spec_[0] == "(":
|
||||
if spec_[-1] != ")":
|
||||
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
|
||||
spec_ = spec_[1:-1]
|
||||
spec_ = spec_.strip()
|
||||
spec_ = spec_.rstrip(",")
|
||||
if not spec_:
|
||||
spec_tuple = ()
|
||||
else:
|
||||
specs = spec_.split(',')
|
||||
def parse_dim(ds: str):
|
||||
if ds == "...":
|
||||
return ...
|
||||
elif ds.isdigit():
|
||||
return int(ds)
|
||||
elif ds == "_" or ds.isalnum():
|
||||
return ds
|
||||
else:
|
||||
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
|
||||
|
||||
spec_tuple = tuple(map(parse_dim, specs))
|
||||
spec_tuple = spec_.split(",") # type: ignore
|
||||
else:
|
||||
raise ValueError(f"PolyShape '{spec}' must be either None, a string, or PolyShape.")
|
||||
|
||||
# Process ...
|
||||
spec_tuple = tuple(map(lambda s: ... if isinstance(s, str) and s.strip() == "..." else s,
|
||||
spec_tuple))
|
||||
ds_ellipses = tuple(ds for ds in spec_tuple if ds == ...)
|
||||
if ds_ellipses:
|
||||
if len(ds_ellipses) > 1 or spec_tuple[-1] != ...:
|
||||
@ -513,29 +519,73 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
|
||||
if len(arg_shape) != len(spec_tuple):
|
||||
raise ValueError(f"PolyShape '{spec}' must match the rank of arguments {arg_shape}.")
|
||||
|
||||
# The actual parsing.
|
||||
# We actually parse not just dimension variables, but polynomials.
|
||||
# This is not a supported feature of the API, but is needed when parsing the
|
||||
# polymorphic_shapes of a gradient function, when the primal function has polynomial
|
||||
# output shapes.
|
||||
def _parse_dim(dim_spec: Union[str, int]) -> DimSize:
  """Parse one dimension specification into an int or a dimension polynomial.

  The accepted string syntax is a sum of terms separated by "+", where each
  term is a product of factors separated by "*", and each factor is either
  an integer literal (optionally negative) or a variable name optionally
  followed by "^" and a non-negative integer exponent.

  Args:
    dim_spec: an integer, which is returned unchanged, or a string to parse.

  Returns:
    The result of adding up the parsed terms.

  Raises:
    ValueError: if `dim_spec` is empty or contains a malformed factor.
  """
  if isinstance(dim_spec, int):
    return dim_spec
  dim_spec = dim_spec.strip()
  if not dim_spec:
    raise ValueError(f"PolyShape '{spec}' has invalid syntax (empty dimension '{dim_spec}')")

  def _parse_factor(factor_spec: str) -> DimSize:
    # A factor is an integer constant, or a variable with optional exponent.
    factor_spec = factor_spec.strip()
    if re.match(r"^-?\d+$", factor_spec):
      return int(factor_spec)
    m = re.match(r"^([a-zA-Z]\w*)(\^(\d+))?$", factor_spec)
    if not m:
      raise ValueError(f"PolyShape '{spec}' has invalid syntax (unexpected term '{factor_spec}')")
    var = _DimPolynomial.from_var(m.group(1))
    if m.group(3) is None:
      return var
    return var ** int(m.group(3))

  def _parse_term(term_spec: str) -> DimSize:
    # Factors are separated by "*". `str.split` with an explicit separator
    # always yields at least one element, so an empty term surfaces as an
    # invalid (empty) factor in `_parse_factor`.
    return functools.reduce(op.mul, map(_parse_factor, term_spec.split("*")))

  # Terms are separated by "+".
  return functools.reduce(op.add, map(_parse_term, dim_spec.split("+")))
|
||||
|
||||
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
|
||||
def _process_dim(i: int, dim_spec):
|
||||
if not isinstance(dim_spec, (str, int)):
|
||||
raise ValueError(f"PolyShape '{spec}' in axis {i} must contain only integers, strings, or Ellipsis.")
|
||||
def _process_dim(i: int, dim_spec: Union[str, int]):
|
||||
if isinstance(dim_spec, str):
|
||||
dim_spec = dim_spec.strip()
|
||||
dim_size = arg_shape[i]
|
||||
if dim_size is None:
|
||||
if dim_spec == "_" or not isinstance(dim_spec, str):
|
||||
if dim_spec == "_":
|
||||
msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
|
||||
f"for unknown dimension in argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return _DimPolynomial.from_var(dim_spec)
|
||||
dim_poly = _parse_dim(dim_spec)
|
||||
if not isinstance(dim_poly, _DimPolynomial):
|
||||
msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
|
||||
f"for unknown dimension in argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return dim_poly
|
||||
else: # dim_size is known
|
||||
if dim_spec == "_":
|
||||
return dim_size
|
||||
if isinstance(dim_spec, int):
|
||||
if dim_spec != dim_size:
|
||||
dim_poly = _parse_dim(dim_spec)
|
||||
if isinstance(dim_poly, int):
|
||||
if dim_poly != dim_size:
|
||||
msg = (f"PolyShape '{spec}' in axis {i} must contain a constant or '_' "
|
||||
f"for known dimension in argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return dim_size
|
||||
# We have a dimension variable for a known dimension.
|
||||
shape_var_map[dim_spec].add(dim_size)
|
||||
return _DimPolynomial.from_var(dim_spec)
|
||||
# We have a dimension polynomial for a known dimension.
|
||||
dim_var = dim_poly.to_var()
|
||||
if dim_var is not None:
|
||||
shape_var_map[dim_spec].add(dim_size) # type: ignore
|
||||
return dim_poly
|
||||
|
||||
dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
|
||||
for dim_var, dim_var_values in shape_var_map.items():
|
||||
|
@ -59,9 +59,28 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))
|
||||
|
||||
a, b = shape_poly.parse_spec("a, b", (2, 3))
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_dim_spec={dim_spec}",
|
||||
dim_spec=dim_spec, dim_poly=dim_poly)
|
||||
for dim_spec, dim_poly in [
|
||||
("2*a*b", 2 * a * b),
|
||||
("-2 * a^2 * b + b^2", -2 * a * a * b + b * b),
|
||||
("-2 * a^2 * b + -1 *b^2*a", -2 * a * a * b - a * b * b),
|
||||
("3 * a * b * a + -2", 3 * a * b * a - 2),
|
||||
("a + 1", a + 1),
|
||||
("a + -1", a - 1),
|
||||
])
|
||||
def test_parse_poly_spec_poly(self, dim_spec="3 * a * b * a + -2", dim_poly=3 * a * b * a - 2):
|
||||
# For internal usage only (the polymorphic_shapes of VJP) we need to
|
||||
# parse polynomials.
|
||||
self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
|
||||
self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
|
||||
|
||||
def test_dim_vars(self):
|
||||
a, b = shape_poly.parse_spec("a, b", (2, 3))
|
||||
a, b, a1 = shape_poly.parse_spec("a, b, a", (2, 3, 2))
|
||||
self.assertEqual(True, a == a)
|
||||
self.assertEqual(True, a == a1)
|
||||
self.assertEqual(False, a != a)
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
@ -145,7 +164,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
|
||||
self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")):
|
||||
re.escape("Dimension polynomial comparison '3*a^2*b + -2' == 'a^2*b' is inconclusive")):
|
||||
(3 * a * b * a - 2).eq(a * b * a)
|
||||
|
||||
def test_poly_compare(self):
|
||||
@ -202,7 +221,6 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(a * 2 // a, 2)
|
||||
self.assertIsInstance(a * 2 // a, int)
|
||||
|
||||
a, b = shape_poly.parse_spec("a, b", (2, 3))
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
|
||||
dividend=dividend, divisor=divisor, quotient=quotient,
|
||||
@ -362,10 +380,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
)
|
||||
|
||||
# Some errors
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("PolyShape ')(' has invalid syntax")):
|
||||
check_avals(
|
||||
args=[const((2, 3))], polymorphic_shapes=[")("], expected_avals=None)
|
||||
for invalid_syntax in [")(", "2a", "a@", "a - 2"]:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("PolyShape '" + invalid_syntax + "' has invalid syntax")):
|
||||
check_avals(
|
||||
args=[const((2,))], polymorphic_shapes=[invalid_syntax],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -592,6 +612,22 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# The shape of the gradient should match the input
|
||||
self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))
|
||||
|
||||
def test_grad_not_var_output(self):
  """Checks gradients when the output dimension is not a plain variable."""
  # The primal function flattens its input, so its output dimension is a
  # polynomial of the input dimension variable rather than the variable itself.
  def f_jax(x):  # x: [b, 3]
    return jnp.reshape(x, (-1,))  # result: [3b]

  x_np = np.arange(12, dtype=np.float32).reshape((4, 3))
  x_var = tf.Variable(x_np)
  converted = jax2tf.convert(f_jax, with_gradient=True,
                             polymorphic_shapes=["b, ..."])

  with tf.GradientTape() as grad_tape:
    primal_out = converted(x_var)
  grad_wrt_x = grad_tape.gradient(primal_out, x_var)

  # Every input element is copied to exactly one output element, so the
  # gradient with respect to the input is all ones.
  expected_grad = np.ones(x_np.shape, dtype=np.float32)
  self.assertAllClose(expected_grad, grad_wrt_x.numpy())
|
||||
|
||||
|
||||
def test_cond(self):
|
||||
# Test the primitive under conditional
|
||||
def f(x, y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user