diff --git a/tests/image_test.py b/tests/image_test.py index a3a5ff3bd..5508cce67 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -21,6 +21,7 @@ import numpy as np from absl.testing import absltest from absl.testing import parameterized +import jax from jax import image from jax import numpy as jnp from jax import test_util as jtu @@ -181,7 +182,7 @@ class ImageTest(jtu.JaxTestCase): jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name":"_shape={}_target={}_method={}".format( + {"testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, @@ -221,7 +222,8 @@ class ImageTest(jtu.JaxTestCase): # Should we test different float types here? scale_a = jnp.array(scale, dtype=jnp.float32) translation_a = jnp.array(translation, dtype=jnp.float32) - output = image.scale_and_translate(x, target_shape, scale_a, translation_a, + output = image.scale_and_translate(x, target_shape, range(len(image_shape)), + scale_a, translation_a, method) expected = np.array( @@ -231,7 +233,7 @@ class ImageTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_method={}_antialias={}".format( jtu.dtype_str(dtype), method, antialias), - "dtype":dtype, "method":method, "antialias": antialias} + "dtype": dtype, "method": method, "antialias": antialias} for dtype in inexact_dtypes for method in ["linear", "lanczos3", "lanczos5", "cubic"] for antialias in [True, False])) @@ -274,12 +276,55 @@ class ImageTest(jtu.JaxTestCase): ] x = np.array(data, dtype=dtype).reshape(image_shape) - scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) - translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) - output = image.scale_and_translate( - x, target_shape, scale_a, translation_a, method, antialias=antialias) expected = np.array( expected_data[method], dtype=dtype).reshape(target_shape) + scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) + translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) + + output = image.scale_and_translate( + x, target_shape, (0,1,2,3), + scale_a, translation_a, method, antialias=antialias) + self.assertAllClose(output, expected, atol=2e-03) + + # Tests that running with just a subset of dimensions that have non-trivial + # scale and translation. + output = image.scale_and_translate( + x, target_shape, (1,2), + scale_a[1:3], translation_a[1:3], method, antialias=antialias) + self.assertAllClose(output, expected, atol=2e-03) + + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "antialias={}".format(antialias), + "antialias": antialias} + for antialias in [True, False])) + def testScaleAndTranslateJITs(self, antialias): + image_shape = [1, 6, 7, 1] + target_shape = [1, 3, 3, 1] + + data = [ + 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, + 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, + 71, 32, 23, 23, 35, 93 + ] + if antialias: + expected_data = [ + 43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0 + ] + else: + expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0] + x = jnp.array(data, dtype=jnp.float32).reshape(image_shape) + + expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape) + scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) + translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) + + def jit_fn(in_array, s, t): + return jax.image.scale_and_translate( + in_array, target_shape, (0, 1, 2, 3), s, t, + "linear", antialias, precision=jax.lax.Precision.HIGHEST) + + output = jax.jit(jit_fn)(x, scale_a, translation_a) self.assertAllClose(output, expected, atol=2e-03)