# rocm_jax/tests/image_test.py
#
# 405 lines
# 17 KiB
# Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools
import unittest
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import image
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
# We use TensorFlow and PIL as reference implementations.
try:
import tensorflow as tf
except ImportError:
tf = None
try:
from PIL import Image as PIL_Image
except ImportError:
PIL_Image = None
# Parse command-line flags through absl before any test is defined
# (presumably so JAX config flags can be set from the command line -- confirm
# against the jax.config documentation).
config.parse_flags_with_absl()
# Dtype groups taken from jtu; the parameterized tests below iterate over them.
float_dtypes = jtu.dtypes.all_floating
inexact_dtypes = jtu.dtypes.inexact
class ImageTest(jtu.JaxTestCase):
  """Tests for `jax.image.resize` and `jax.image.scale_and_translate`.

  Outputs are checked against TensorFlow and PIL when those packages could be
  imported, and against hard-coded expected values in the remaining tests.
  """

  # Six pixel values shared by the upscaling tests (testResizeUp and
  # testScaleAndTranslateUp), which reshape them to the parameterized shape.
  _UP_DATA = (64, 32, 32, 64, 50, 100)

  # Fixture shared by the scale_and_translate downscaling tests: a
  # [1, 6, 7, 1] image that is resampled to [1, 3, 3, 1].
  _DOWN_IMAGE_SHAPE = (1, 6, 7, 1)
  _DOWN_TARGET_SHAPE = (1, 3, 3, 1)
  _DOWN_DATA = (
      51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
      41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
      71, 32, 23, 23, 35, 93
  )
  _DOWN_SCALE = (1.0, 0.35, 0.4, 1.0)
  _DOWN_TRANSLATION = (0.0, 0.2, 0.1, 0.0)
  # Expected flattened outputs for the fixture above, keyed first by the
  # `antialias` flag and then by resize method.
  _DOWN_EXPECTED = {
      True: {
          "linear": [
              43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
          ],
          "lanczos3": [
              43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
          ],
          "lanczos5": [
              43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
          ],
          "cubic": [
              42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
          ],
      },
      False: {
          "linear": [
              43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
          ],
          "lanczos3": [
              44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
          ],
          "lanczos5": [
              44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
          ],
          "cubic": [
              43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
          ],
      },
  }

  def _down_scale_and_translation(self):
    """Returns the shared downscaling (scale, translation) as float32 arrays."""
    scale_a = jnp.array(self._DOWN_SCALE, dtype=jnp.float32)
    translation_a = jnp.array(self._DOWN_TRANSLATION, dtype=jnp.float32)
    return scale_a, translation_a

  # NOTE: itertools.combinations_with_replacement yields each pair with the
  # earlier list entry first, so in the grids below `target_shape` is never
  # larger than `image_shape`.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in float_dtypes
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method,
                                  antialias):
    """Compares `image.resize` with `tf.image.resize` on random inputs.

    Args:
      dtype: floating-point dtype of the input image and expected output.
      image_shape: shape of the random input image.
      target_shape: full output shape requested from `image.resize`; TensorFlow
        receives only its inner dimensions (`target_shape[1:-1]`).
      method: resize method name, passed to both libraries.
      antialias: antialiasing flag, passed to both libraries.

    Raises:
      unittest.SkipTest: for bicubic resizing whenever any target dimension is
        smaller than the corresponding image dimension.
    """
    # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
    # for some cases of non-antialiased bicubic downscaling; we would expect
    # exact equality.
    # NOTE(review): the condition below does not look at `antialias`, so the
    # antialiased bicubic downscaling cases are skipped as well.
    is_downscaling = any(x < y for x, y in zip(target_shape, image_shape))
    if method == "bicubic" and is_downscaling:
      raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)

    def tf_fn(x):
      # The reference is computed from a float64 copy of the input and the
      # result is cast back to the dtype under test.
      out = tf.image.resize(
          x.astype(np.float64), tf.constant(target_shape[1:-1]),
          method=method, antialias=antialias).numpy().astype(dtype)
      return out

    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
                            tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
                                 np.float32: 1e-4, np.float64: 1e-4})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method}
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
  @unittest.skipIf(not PIL_Image, "Test requires PIL")
  def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
    """Compares antialiased `image.resize` with PIL's `Image.resize`.

    Args:
      dtype: dtype of the input image and expected output.
      image_shape: 2D shape of the random input image.
      target_shape: 2D output shape; PIL receives it reversed.
      method: resize method name, mapped to the matching PIL filter.
    """
    rng = jtu.rand_uniform(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)

    def pil_fn(x):
      pil_methods = {
          "nearest": PIL_Image.NEAREST,
          "bilinear": PIL_Image.BILINEAR,
          "bicubic": PIL_Image.BICUBIC,
          "lanczos3": PIL_Image.LANCZOS,
      }
      img = PIL_Image.fromarray(x.astype(np.float32))
      # `target_shape` is reversed for PIL (presumably because PIL sizes are
      # (width, height) -- confirm against the Pillow documentation).
      out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]),
                       dtype=dtype)
      return out

    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=True)
    self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape,
       "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape in [
          ([3, 1, 2], [6, 1, 4]),
          ([1, 3, 2, 1], [1, 6, 4, 1]),
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]))
  def testResizeUp(self, dtype, image_shape, target_shape, method):
    """Checks 2x upscaling of a six-pixel image against hard-coded values.

    Args:
      dtype: dtype of the input image and expected output.
      image_shape: shape the six input values are reshaped to.
      target_shape: output shape requested from `image.resize`.
      method: resize method name selecting the expected values.
    """
    expected_data = {
        "nearest": [
            64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0,
            64.0, 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0,
            100.0, 100.0
        ],
        "linear": [
            64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0,
            56.0, 36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0,
            62.5, 87.5, 100.0
        ],
        "lanczos3": [
            75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454,
            31.964, 35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769,
            67.1244, 85.5045, 35.7939, 56.4713, 83.5243, 104.2017, 44.8138,
            65.1949, 91.8603, 112.2413
        ],
        "lanczos5": [
            77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593,
            28.9709, 35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299,
            67.7223, 89.658, 32.1213, 56.784, 83.984, 108.6467, 44.5802,
            66.183, 90.0082, 111.6109
        ],
        "cubic": [
            70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789,
            35.4981, 36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151,
            65.8032, 77.731, 41.6492, 55.823, 83.9288, 98.1026, 47.0363,
            62.2744, 92.4903, 107.7284
        ],
    }
    x = np.array(self._UP_DATA, dtype=dtype).reshape(image_shape)
    output = image.resize(x, target_shape, method)
    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=1e-04)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                          antialias):
    """Runs `jtu.check_grads` (up to second order) on `image.resize`.

    Args:
      dtype: dtype of the random input image.
      image_shape: shape of the random input image.
      target_shape: output shape requested from `image.resize`.
      method: resize method name.
      antialias: antialiasing flag passed to `image.resize`.
    """
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in [np.float32]
      for image_shape, target_shape in [
          ([1], [0]),
          ([5, 5], [5, 0]),
          ([5, 5], [0, 1]),
          ([5, 5], [0, 0])
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [False, True]))
  def testResizeEmpty(self, dtype, image_shape, target_shape, method,
                      antialias):
    """Checks that resizing to a shape with a zero dimension works.

    Args:
      dtype: dtype of the all-ones input image.
      image_shape: shape of the input image.
      target_shape: output shape containing at least one zero dimension.
      method: resize method name.
      antialias: antialiasing flag passed to `jax.image.resize`.
    """
    # Regression test for https://github.com/google/jax/issues/7586
    # The input is named `x` so that it does not shadow the module-level
    # `image` alias of `jax.image`.
    x = np.ones(image_shape, dtype)
    out = jax.image.resize(
        x, shape=target_shape, method=method, antialias=antialias)
    self.assertArraysEqual(out, jnp.zeros(target_shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "scale": scale, "translation": translation, "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape, scale, translation in [
          ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]),
          ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0],
           [0.0, 1.0, -1.0, 0.0])]
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]))
  def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                              translation, method):
    """Checks upscaling with translation against hard-coded values.

    Args:
      dtype: dtype of the input image and expected output.
      image_shape: shape the six input values are reshaped to.
      target_shape: output shape requested from `scale_and_translate`.
      scale: per-dimension scale factors (converted to float32).
      translation: per-dimension translations (converted to float32).
      method: resize method name selecting the expected values.
    """
    # Note zeros occur in the output because the sampling location is outside
    # the boundaries of the input image.
    expected_data = {
        "linear": [
            0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0,
            44.0, 52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625,
            91.0, 0.0
        ],
        "lanczos3": [
            0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454,
            31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244,
            85.5045, 0.0, 56.4713, 83.5243, 104.2017, 0.0
        ],
        "lanczos5": [
            0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369,
            39.5593, 28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299,
            67.7223, 89.658, 0.0, 56.784, 83.984, 108.6467, 0.0
        ],
        "cubic": [
            0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386,
            41.4789, 35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151,
            65.8032, 77.731, 0.0, 55.823, 83.9288, 98.1026, 0.0
        ],
    }
    x = np.array(self._UP_DATA, dtype=dtype).reshape(image_shape)
    # Should we test different float types here?
    scale_a = jnp.array(scale, dtype=jnp.float32)
    translation_a = jnp.array(translation, dtype=jnp.float32)
    output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
                                       scale_a, translation_a,
                                       method)
    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_method={}_antialias={}".format(
          jtu.dtype_str(dtype), method, antialias),
       "dtype": dtype, "method": method, "antialias": antialias}
      for dtype in inexact_dtypes
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [True, False]))
  def testScaleAndTranslateDown(self, dtype, method, antialias):
    """Checks downscaling with translation against hard-coded values.

    The same expectation is checked twice: once passing all four dimensions as
    spatial dimensions, and once passing only dimensions 1 and 2.

    Args:
      dtype: dtype of the input image and expected output.
      method: resize method name selecting the expected values.
      antialias: antialiasing flag selecting the expected values.
    """
    image_shape = self._DOWN_IMAGE_SHAPE
    target_shape = self._DOWN_TARGET_SHAPE
    x = np.array(self._DOWN_DATA, dtype=dtype).reshape(image_shape)
    expected = np.array(
        self._DOWN_EXPECTED[antialias][method],
        dtype=dtype).reshape(target_shape)
    scale_a, translation_a = self._down_scale_and_translation()

    output = image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3),
        scale_a, translation_a, method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

    # Tests that running with just a subset of dimensions that have non-trivial
    # scale and translation.
    output = image.scale_and_translate(
        x, target_shape, (1, 2),
        scale_a[1:3], translation_a[1:3], method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateJITs(self, antialias):
    """Checks that linear `scale_and_translate` gives expected values under jit.

    Args:
      antialias: antialiasing flag selecting the expected values.
    """
    image_shape = self._DOWN_IMAGE_SHAPE
    target_shape = self._DOWN_TARGET_SHAPE
    expected_data = self._DOWN_EXPECTED[antialias]["linear"]
    x = jnp.array(self._DOWN_DATA, dtype=jnp.float32).reshape(image_shape)
    expected = jnp.array(
        expected_data, dtype=jnp.float32).reshape(target_shape)
    scale_a, translation_a = self._down_scale_and_translation()

    def jit_fn(in_array, s, t):
      # Only the arrays are jit arguments; shape, method and antialias are
      # closed over.
      return jax.image.scale_and_translate(
          in_array, target_shape, (0, 1, 2, 3), s, t,
          "linear", antialias, precision=jax.lax.Precision.HIGHEST)

    output = jax.jit(jit_fn)(x, scale_a, translation_a)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateGradFinite(self, antialias):
    """Checks that gradients w.r.t. scale and translation are finite.

    Args:
      antialias: antialiasing flag passed to `scale_and_translate`.
    """
    image_shape = self._DOWN_IMAGE_SHAPE
    target_shape = self._DOWN_TARGET_SHAPE
    x = jnp.array(self._DOWN_DATA, dtype=jnp.float32).reshape(image_shape)
    scale_a, translation_a = self._down_scale_and_translation()

    def scale_fn(s):
      return jnp.sum(jax.image.scale_and_translate(
          x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
          precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
      return jnp.sum(jax.image.scale_and_translate(
          x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
          precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))

  def testScaleAndTranslateNegativeDims(self):
    """Checks that negative spatial dims match their non-negative equivalents."""
    data = jnp.full((3, 3), 0.5)
    actual = jax.image.scale_and_translate(
        data, (5, 5), (-2, -1), jnp.ones(2), jnp.zeros(2), "linear")
    expected = jax.image.scale_and_translate(
        data, (5, 5), (0, 1), jnp.ones(2), jnp.zeros(2), "linear")
    self.assertAllClose(actual, expected)

  def testResizeWithUnusualShapes(self):
    """Checks array-valued shapes are accepted and fractional ones rejected."""
    x = jnp.ones((3, 4))
    # Array shapes are accepted
    self.assertEqual((10, 17),
                     jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
    with self.assertRaises(TypeError):
      # Fractional shapes are disallowed
      jax.image.resize(x, [10.5, 17], "bicubic")
if __name__ == "__main__":
  # Run this module's tests through absltest using the loader provided by jtu.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)