[jax2tf] Added support for shape-polymorphic reductionsxs

This commit is contained in:
George Necula 2021-04-09 11:10:32 +03:00
parent 45bae37278
commit 7667fc3be7
3 changed files with 23 additions and 4 deletions

View File

@ -1975,7 +1975,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
if initial is None and not has_identity:
if not size(a):
if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
if where_ is not None:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "

View File

@ -14,6 +14,7 @@
"""Tests for the shape-polymorphic jax2tf conversion."""
from absl.testing import absltest
from absl.testing import parameterized
from typing import Dict, Optional, Sequence, Union
import collections
@ -992,8 +993,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(f_jax(x), f_tf(x))
def test_random_gamma(self):
if "random_gamma" in _VMAP_NOT_POLY_YET:
raise unittest.SkipTest("TODO: vmap(random_gamma) not yet supported")
assert "random_gamma" in _VMAP_NOT_POLY_YET
raise unittest.SkipTest("TODO: vmap(random_gamma) not yet supported")
def f_jax(key, a):
return jax.random.gamma(key, a)
@ -1012,6 +1013,24 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
expected_output_signature=tf.TensorSpec([None, 3]))
self.assertAllClose(f_jax(key, a), f_tf(key, a))
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{op.__name__}", op=op)
for op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]))
def test_reduce(self, op=jnp.max):
f_jax = lambda x: op(x, axis=-1, keepdims=True)
x = np.random.rand(7, 8)
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[
tf.TensorSpec([None, 8], dtype=x.dtype),
],
polymorphic_shapes=["(batch, ...)"],
expected_output_signature=tf.TensorSpec([None, 1], dtype=x.dtype))
self.assertAllClose(f_jax(x), f_tf(x))
def test_reshape(self):
self.CheckShapePolymorphism(

View File

@ -742,7 +742,7 @@ class MaskingTest(jtu.JaxTestCase):
'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
def test_reduce(self, operator):
self.check(operator, ['(m, n)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
jtu.rand_default(self.rng()))
def test_output_shape_error(self):