Updating image_test

This commit is contained in:
johnpjf 2020-09-21 16:21:48 -07:00 committed by GitHub
parent be50847cee
commit ae910cdd31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,6 +21,7 @@ import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import image
from jax import numpy as jnp
from jax import test_util as jtu
@ -181,7 +182,7 @@ class ImageTest(jtu.JaxTestCase):
jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":"_shape={}_target={}_method={}".format(
{"testcase_name": "_shape={}_target={}_method={}".format(
jtu.format_shape_dtype_string(image_shape, dtype),
jtu.format_shape_dtype_string(target_shape, dtype), method),
"dtype": dtype, "image_shape": image_shape,
@ -221,7 +222,8 @@ class ImageTest(jtu.JaxTestCase):
# Should we test different float types here?
scale_a = jnp.array(scale, dtype=jnp.float32)
translation_a = jnp.array(translation, dtype=jnp.float32)
output = image.scale_and_translate(x, target_shape, scale_a, translation_a,
output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
scale_a, translation_a,
method)
expected = np.array(
@ -231,7 +233,7 @@ class ImageTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_method={}_antialias={}".format(
jtu.dtype_str(dtype), method, antialias),
"dtype":dtype, "method":method, "antialias": antialias}
"dtype": dtype, "method": method, "antialias": antialias}
for dtype in inexact_dtypes
for method in ["linear", "lanczos3", "lanczos5", "cubic"]
for antialias in [True, False]))
@ -274,12 +276,55 @@ class ImageTest(jtu.JaxTestCase):
]
x = np.array(data, dtype=dtype).reshape(image_shape)
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
output = image.scale_and_translate(
x, target_shape, scale_a, translation_a, method, antialias=antialias)
expected = np.array(
expected_data[method], dtype=dtype).reshape(target_shape)
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
output = image.scale_and_translate(
x, target_shape, (0,1,2,3),
scale_a, translation_a, method, antialias=antialias)
self.assertAllClose(output, expected, atol=2e-03)
# Tests that running with just a subset of dimensions that have non-trivial
# scale and translation.
output = image.scale_and_translate(
x, target_shape, (1,2),
scale_a[1:3], translation_a[1:3], method, antialias=antialias)
self.assertAllClose(output, expected, atol=2e-03)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "antialias={}".format(antialias),
"antialias": antialias}
for antialias in [True, False]))
def testScaleAndTranslateJITs(self, antialias):
image_shape = [1, 6, 7, 1]
target_shape = [1, 3, 3, 1]
data = [
51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
71, 32, 23, 23, 35, 93
]
if antialias:
expected_data = [
43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
]
else:
expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0]
x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape)
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
def jit_fn(in_array, s, t):
return jax.image.scale_and_translate(
in_array, target_shape, (0, 1, 2, 3), s, t,
"linear", antialias, precision=jax.lax.Precision.HIGHEST)
output = jax.jit(jit_fn)(x, scale_a, translation_a)
self.assertAllClose(output, expected, atol=2e-03)