Refactor the jax_to_tf tests to separate the primitive test harness (#3376)

* Refactor the jax_to_tf tests to separate the primitive test harness from the
tests themselves.

The goal is to have a collection of test harnesses for the JAX primitives
that can be used to test various implementations (JAX, NumPy, TensorFlow).
For now we use these harnesses only in the jax_to_tf tests, although we can
later use them for lax_test as well.

Demonstrate the use of the harnesses for lax.pad and lax.squeeze, both in
tf_ops_test and lax_test. The plan is to add support for more primitives
as we make progress testing jax_to_tf.

* Expanded pad harness with negative pads
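
As a sketch of the intended reuse beyond jax_to_tf (the test below is
hypothetical and not part of this diff; it assumes lax_reference.pad mirrors
lax.pad's (operand, padding_value, padding_config) signature), the same
harness could drive a NumPy-reference check in lax_test:

  @primitive_harness.parameterized(primitive_harness.lax_pad)
  def testPadAgainstNumpy(self, harness: primitive_harness.Harness):
    # The dynamic arguments of a lax_pad harness are the operand and the
    # padding value; the padding config is static and recorded in params.
    operand, padding_value = harness.dyn_args_maker(self.rng())
    self.assertAllClose(
        lax_reference.pad(operand, padding_value, harness.params["pads"]),
        harness.dyn_fun(operand, padding_value))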
George Necula 2020-06-09 11:07:32 +03:00 committed by GitHub
parent 6aa8f2461c
commit c12541d5f5
2 changed files with 189 additions and 13 deletions

jax/experimental/jax_to_tf/tests/primitive_harness.py (new file)

@@ -0,0 +1,170 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines test inputs and invocations for JAX primitives.
Used to test various implementations of JAX primitives, e.g., against
NumPy (lax_reference) or TensorFlow.
"""
from typing import Any, Callable, Dict, Iterable, Optional, NamedTuple, Sequence, Tuple, Union
from absl import testing
from jax import test_util as jtu
from jax import dtypes
from jax import lax
import numpy as np
# TODO: these are copied from tests/lax_test.py (make that the single source
# of truth).
def supported_dtypes(dtypes):
  return [t for t in dtypes if t in jtu.supported_dtypes()]

float_dtypes = supported_dtypes([dtypes.bfloat16, np.float16, np.float32,
                                 np.float64])
complex_elem_dtypes = supported_dtypes([np.float32, np.float64])
complex_dtypes = supported_dtypes([np.complex64, np.complex128])
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = supported_dtypes([np.int32, np.int64])
uint_dtypes = supported_dtypes([np.uint32, np.uint64])
bool_dtypes = [np.bool_]
default_dtypes = float_dtypes + int_dtypes
all_dtypes = float_dtypes + complex_dtypes + int_dtypes + bool_dtypes
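# For example, on a test machine that supports all of the dtypes above,
# default_dtypes == [dtypes.bfloat16, np.float16, np.float32, np.float64,
#                    np.int32, np.int64].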

Rng = Any  # A random number generator.


class RandArg(NamedTuple):
  """Descriptor for a randomly generated argument."""
  shape: Tuple[int, ...]
  dtype: np.dtype


class StaticArg(NamedTuple):
  """Descriptor for a static argument."""
  value: Any


class Harness:
  """Specifies inputs and the callable for testing a primitive.

  A primitive can take dynamic and static arguments. The dynamic arguments
  are generated with an RNG and are numeric (and hence appropriate for JIT);
  the static arguments are fixed values. See the example below this class.
  """
  # Descriptive name of the harness, used as a testcase_name; unique within
  # a harness group.
  name: str
  # The function taking all arguments (static and dynamic).
  fun: Callable
  arg_descriptors: Sequence[Union[RandArg, StaticArg, Any]]
  rng_factory: Callable
  params: Dict[str, Any]

  def __init__(self, name, fun, arg_descriptors, *,
               rng_factory=jtu.rand_default, **params):
    self.name = name
    self.fun = fun
    self.arg_descriptors = arg_descriptors
    self.rng_factory = rng_factory
    self.params = params

  def __str__(self):
    return self.name

  def _arg_maker(self, arg_descriptor, rng: Rng):
    if isinstance(arg_descriptor, StaticArg):
      return arg_descriptor.value
    if isinstance(arg_descriptor, RandArg):
      return self.rng_factory(rng)(arg_descriptor.shape, arg_descriptor.dtype)
    return arg_descriptor

  def args_maker(self, rng: Rng) -> Sequence:
    """All-argument maker, including the static arguments."""
    return [self._arg_maker(ad, rng) for ad in self.arg_descriptors]

  def dyn_args_maker(self, rng: Rng) -> Sequence:
    """Dynamic-argument maker, for use with `dyn_fun`."""
    return [self._arg_maker(ad, rng) for ad in self.arg_descriptors
            if not isinstance(ad, StaticArg)]

  def dyn_fun(self, *dyn_args):
    """Invokes `fun` given just the dynamic arguments."""
    all_args = self._args_from_dynargs(dyn_args)
    return self.fun(*all_args)

  def _args_from_dynargs(self, dyn_args: Sequence) -> Sequence:
    """Reconstructs all arguments, interleaving the static ones."""
    next_dynamic_argnum = 0
    all_args = []
    for ad in self.arg_descriptors:
      if isinstance(ad, StaticArg):
        all_args.append(ad.value)
      else:
        all_args.append(dyn_args[next_dynamic_argnum])
        next_dynamic_argnum += 1
    return all_args
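
# Example usage of a Harness (a sketch; `slice_harness` is hypothetical and
# not part of this commit). The operand is dynamic, while the slice bounds
# are static:
#
#   slice_harness = Harness(
#       "_slice_example", lax.slice,
#       [RandArg((3, 4), np.float32),  # dynamic operand
#        StaticArg((0, 1)),            # static start_indices
#        StaticArg((2, 3))])           # static limit_indices
#   rng = np.random.RandomState(0)
#   dyn_args = slice_harness.dyn_args_maker(rng)  # -> [a random (3, 4) array]
#   slice_harness.dyn_fun(*dyn_args)  # static args are interleaved back in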


def parameterized(harness_group: Iterable[Harness],
                  one_containing: Optional[str] = None):
  """Decorator for tests; the decorated test receives a `harness` argument.

  The `one_containing` parameter is useful for debugging: if given, only the
  first harness whose name contains the given string is kept, so the whole
  set of parameterized tests is reduced to a single test, whose name is left
  undecorated to make it easier to select for running. See the sketch below.
  """
  cases = tuple(
      dict(testcase_name=harness.name if one_containing is None else "",
           harness=harness)
      for harness in harness_group
      if one_containing is None or one_containing in harness.name)
  if one_containing is not None:
    cases = cases[0:1]
  return testing.parameterized.named_parameters(*cases)
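
# Debugging sketch (the substring below is hypothetical): run only the one
# harness whose name contains the given string.
#
#   @parameterized(lax_pad, one_containing="pads=[(1, 1, 0), (2, 2, 0)]")
#   def test_pad(self, harness): ...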


lax_pad = jtu.cases_from_list(
    Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
            lax.pad,
            [RandArg(arg_shape, dtype), np.array(0, dtype), StaticArg(pads)],
            rng_factory=jtu.rand_small,
            arg_shape=arg_shape, dtype=dtype, pads=pads)
    for arg_shape in [(2, 3)]
    for dtype in default_dtypes
    for pads in [
        [(0, 0, 0), (0, 0, 0)],    # no padding
        [(1, 1, 0), (2, 2, 0)],    # edge padding
        [(1, 2, 1), (0, 1, 0)],    # edge padding and interior padding
        [(0, 0, 0), (-1, -1, 0)],  # negative padding
        [(0, 0, 0), (-2, -2, 4)],  # add big dilation, then remove from the edges
        [(0, 0, 0), (-2, -3, 1)],  # remove everything in one dimension
    ]
)
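
# For reference (illustration only, not used by the harnesses): each entry of
# `pads` is a (low, high, interior) triple for one dimension. For example,
#   lax.pad(np.array([1., 2.]), 0., [(1, 1, 1)])
# inserts one zero between the elements and one zero on each edge, giving
# [0., 1., 0., 2., 0.]; negative low/high values instead trim from that edge.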


lax_squeeze = jtu.cases_from_list(
    Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_dimensions={dimensions}",  # type: ignore
            lax.squeeze,
            [RandArg(arg_shape, dtype), StaticArg(dimensions)],  # type: ignore[has-type]
            arg_shape=arg_shape, dtype=dtype, dimensions=dimensions)  # type: ignore[has-type]
    for arg_shape, dimensions in [
        [(1,), (0,)],
        [(1,), (-1,)],
        [(2, 1, 4), (1,)],
        [(2, 1, 4), (-2,)],
        [(2, 1, 3, 1), (1,)],
        [(2, 1, 3, 1), (1, 3)],
        [(2, 1, 3, 1), (3,)],
        [(2, 1, 3, 1), (1, -1)],
    ]
    for dtype in [np.float32]
)
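
# For reference (illustration only): lax.squeeze removes the listed size-1
# dimensions, e.g. lax.squeeze(np.ones((2, 1, 4)), (1,)) has shape (2, 4);
# negative dimensions count from the end, so (-2,) names the same axis here.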

jax/experimental/jax_to_tf/tests/tf_ops_test.py

@@ -13,10 +13,13 @@
# limitations under the License.
"""Tests for the jax_to_tf transformation."""
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import dtypes
import jax.lax as lax
import jax.numpy as jnp
from jax import test_util as jtu
@@ -24,12 +27,13 @@ import numpy as np
import tensorflow as tf # type: ignore[import]
from jax.experimental import jax_to_tf
-from jax.experimental.jax_to_tf.tests import tf_test_util
+from . import tf_test_util
+from . import primitive_harness
from jax.config import config
config.parse_flags_with_absl()
# TODO(tomhennigan) Increase coverage here.
LAX_ELEMENTWISE_UNARY = (
  lax.abs,
@@ -170,10 +174,15 @@ class TfOpsTest(tf_test_util.JaxToTfTestCase):
    f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
    self.ConvertAndCompare(f_jax, values, with_function=True)

-  def test_pad(self):
-    values = np.array([1, 2], dtype=np.float32)
-    f_jax = jax.jit(lambda x: jax.lax.pad(x, 0.0, [(3, 1, 2)]))
-    self.ConvertAndCompare(f_jax, values, with_function=True)
+  @primitive_harness.parameterized(primitive_harness.lax_pad)
+  def test_pad(self, harness: primitive_harness.Harness):
+    if harness.params["dtype"] is dtypes.bfloat16:
+      raise unittest.SkipTest("bfloat16 not implemented")
+    # TODO: implement (or decide not to implement) pads with negative edge padding
+    if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
+      raise unittest.SkipTest("pad with negative pad not supported")
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                           with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}",
@@ -252,13 +261,10 @@ class TfOpsTest(tf_test_util.JaxToTfTestCase):
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

-  # TODO(necula): replace these tests with LAX reference tests
-  def test_squeeze(self):
-    shape = (2, 1, 3, 1)
-    values = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
-    for squeeze_dims in ((1,), (3,), (1, 3,)):
-      f_jax = jax.jit(lambda v: jnp.squeeze(v, axis=squeeze_dims))  # pylint: disable=cell-var-from-loop
-      self.ConvertAndCompare(f_jax, values, with_function=True)
+  @primitive_harness.parameterized(primitive_harness.lax_squeeze)
+  def test_squeeze(self, harness: primitive_harness.Harness):
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                           with_function=True)

  def test_gather(self):
    values = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)