mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Added support for shape-polymorphic reductionsxs
This commit is contained in:
parent
45bae37278
commit
7667fc3be7
@ -1975,7 +1975,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
|
||||
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
|
||||
|
||||
if initial is None and not has_identity:
|
||||
if not size(a):
|
||||
if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
|
||||
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
|
||||
if where_ is not None:
|
||||
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
|
||||
|
@ -14,6 +14,7 @@
|
||||
"""Tests for the shape-polymorphic jax2tf conversion."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import collections
|
||||
@ -992,8 +993,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertAllClose(f_jax(x), f_tf(x))
|
||||
|
||||
def test_random_gamma(self):
|
||||
if "random_gamma" in _VMAP_NOT_POLY_YET:
|
||||
raise unittest.SkipTest("TODO: vmap(random_gamma) not yet supported")
|
||||
assert "random_gamma" in _VMAP_NOT_POLY_YET
|
||||
raise unittest.SkipTest("TODO: vmap(random_gamma) not yet supported")
|
||||
|
||||
def f_jax(key, a):
|
||||
return jax.random.gamma(key, a)
|
||||
@ -1012,6 +1013,24 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
expected_output_signature=tf.TensorSpec([None, 3]))
|
||||
self.assertAllClose(f_jax(key, a), f_tf(key, a))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{op.__name__}", op=op)
|
||||
for op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]))
|
||||
def test_reduce(self, op=jnp.max):
|
||||
f_jax = lambda x: op(x, axis=-1, keepdims=True)
|
||||
|
||||
x = np.random.rand(7, 8)
|
||||
f_tf = self.CheckShapePolymorphism(
|
||||
f_jax,
|
||||
input_signature=[
|
||||
tf.TensorSpec([None, 8], dtype=x.dtype),
|
||||
],
|
||||
polymorphic_shapes=["(batch, ...)"],
|
||||
expected_output_signature=tf.TensorSpec([None, 1], dtype=x.dtype))
|
||||
|
||||
self.assertAllClose(f_jax(x), f_tf(x))
|
||||
|
||||
def test_reshape(self):
|
||||
|
||||
self.CheckShapePolymorphism(
|
||||
|
@ -742,7 +742,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
|
||||
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
|
||||
def test_reduce(self, operator):
|
||||
self.check(operator, ['(m, n)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
|
||||
self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
def test_output_shape_error(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user