mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00

This makes jax.image.resize() robust to having jnp arrays passed as sizes. It also turned out that some users were passing floating-point values as sizes; these are now correctly flagged as errors. PiperOrigin-RevId: 399996702
397 lines
17 KiB
Python
397 lines
17 KiB
Python
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import itertools
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import jax
|
|
from jax import image
|
|
from jax import numpy as jnp
|
|
from jax._src import test_util as jtu
|
|
|
|
from jax.config import config
|
|
|
|
# We use TensorFlow and PIL as reference implementations.
# Both are optional dependencies: if an import fails, the corresponding name
# is bound to None so the tests that need it can be skipped via
# `unittest.skipIf` instead of breaking the whole module at import time.
try:
  import tensorflow as tf
except ImportError:
  tf = None

try:
  from PIL import Image as PIL_Image
except ImportError:
  PIL_Image = None

# Let absl flags configure JAX before any test is constructed.
config.parse_flags_with_absl()

# Dtype collections used to parameterize the test cases below.
float_dtypes = jtu.dtypes.all_floating
inexact_dtypes = jtu.dtypes.inexact
|
|
|
|
class ImageTest(jtu.JaxTestCase):
  """Tests for `jax.image.resize` and `jax.image.scale_and_translate`."""

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}_antialias={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype),
              method, antialias),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          method=method,
          antialias=antialias)
      for dtype in float_dtypes
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method,
                                  antialias):
    # Compares image.resize with tf.image.resize on random inputs.
    #
    # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
    # for some cases of non-antialiased bicubic downscaling; we would expect
    # exact equality.
    shrinks_some_dim = any(
        x < y for x, y in zip(target_shape, image_shape))
    if method == "bicubic" and shrinks_some_dim:
      raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch")

    rng = jtu.rand_default(self.rng())

    def args_maker():
      return (rng(image_shape, dtype),)

    def tf_fn(x):
      # TF is given float64 input and only the inner dimensions of the target
      # shape; the result is cast back to the dtype under test.
      resized = tf.image.resize(
          x.astype(np.float64), tf.constant(target_shape[1:-1]),
          method=method, antialias=antialias)
      return resized.numpy().astype(dtype)

    jax_fn = partial(
        image.resize, shape=target_shape, method=method, antialias=antialias)
    tolerances = {
        np.float16: 2e-2,
        jnp.bfloat16: 1e-1,
        np.float32: 1e-4,
        np.float64: 1e-4,
    }
    self._CheckAgainstNumpy(
        tf_fn, jax_fn, args_maker, check_dtypes=True, tol=tolerances)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype), method),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          method=method)
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
  @unittest.skipIf(not PIL_Image, "Test requires PIL")
  def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
    # Compares antialiased image.resize with PIL's Image.resize.
    rng = jtu.rand_uniform(self.rng())

    def args_maker():
      return (rng(image_shape, dtype),)

    # Maps the JAX method name onto the matching PIL resampling filter.
    pil_filters = {
        "nearest": PIL_Image.NEAREST,
        "bilinear": PIL_Image.BILINEAR,
        "bicubic": PIL_Image.BICUBIC,
        "lanczos3": PIL_Image.LANCZOS,
    }

    def pil_fn(x):
      pil_image = PIL_Image.fromarray(x.astype(np.float32))
      # The target shape is passed to PIL in reversed order.
      resized = pil_image.resize(target_shape[::-1], pil_filters[method])
      return np.asarray(resized, dtype=dtype)

    jax_fn = partial(
        image.resize, shape=target_shape, method=method, antialias=True)
    self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype), method),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          method=method)
      for dtype in inexact_dtypes
      for image_shape, target_shape in [
          ([3, 1, 2], [6, 1, 4]),
          ([1, 3, 2, 1], [1, 6, 4, 1]),
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]))
  def testResizeUp(self, dtype, image_shape, target_shape, method):
    # Upscales a fixed six-element image and checks against stored values.
    pixels = [64, 32, 32, 64, 50, 100]
    expected_by_method = {
        "nearest": [
            64.0, 64.0, 32.0, 32.0,
            64.0, 64.0, 32.0, 32.0,
            32.0, 32.0, 64.0, 64.0,
            32.0, 32.0, 64.0, 64.0,
            50.0, 50.0, 100.0, 100.0,
            50.0, 50.0, 100.0, 100.0,
        ],
        "linear": [
            64.0, 56.0, 40.0, 32.0,
            56.0, 52.0, 44.0, 40.0,
            40.0, 44.0, 52.0, 56.0,
            36.5, 45.625, 63.875, 73.0,
            45.5, 56.875, 79.625, 91.0,
            50.0, 62.5, 87.5, 100.0,
        ],
        "lanczos3": [
            75.8294, 59.6281, 38.4313, 22.23,
            60.6851, 52.0037, 40.6454, 31.964,
            35.8344, 41.0779, 47.9383, 53.1818,
            24.6968, 43.0769, 67.1244, 85.5045,
            35.7939, 56.4713, 83.5243, 104.2017,
            44.8138, 65.1949, 91.8603, 112.2413,
        ],
        "lanczos5": [
            77.5699, 60.0223, 40.6694, 23.1219,
            61.8253, 51.2369, 39.5593, 28.9709,
            35.7438, 40.8875, 46.5604, 51.7041,
            21.5942, 43.5299, 67.7223, 89.658,
            32.1213, 56.784, 83.984, 108.6467,
            44.5802, 66.183, 90.0082, 111.6109,
        ],
        "cubic": [
            70.1453, 59.0252, 36.9748, 25.8547,
            59.3195, 53.3386, 41.4789, 35.4981,
            36.383, 41.285, 51.0051, 55.9071,
            30.2232, 42.151, 65.8032, 77.731,
            41.6492, 55.823, 83.9288, 98.1026,
            47.0363, 62.2744, 92.4903, 107.7284,
        ],
    }
    x = np.array(pixels, dtype=dtype).reshape(image_shape)
    output = image.resize(x, target_shape, method)
    expected = np.array(
        expected_by_method[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=1e-04)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}_antialias={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype),
              method, antialias),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          method=method,
          antialias=antialias)
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                          antialias):
    # Checks first- and second-order gradients of image.resize numerically.
    rng = jtu.rand_default(self.rng())
    resize_fn = partial(
        image.resize, shape=target_shape, method=method, antialias=antialias)
    args = (rng(image_shape, dtype),)
    jtu.check_grads(resize_fn, args, order=2, rtol=1e-2, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}_antialias={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype),
              method, antialias),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          method=method,
          antialias=antialias)
      for dtype in [np.float32]
      for image_shape, target_shape in [
          ([1], [0]),
          ([5, 5], [5, 0]),
          ([5, 5], [0, 1]),
          ([5, 5], [0, 0]),
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [False, True]))
  def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias):
    # Regression test for https://github.com/google/jax/issues/7586
    # Resizing to a shape with a zero-sized dimension yields an empty array.
    ones_image = np.ones(image_shape, dtype)
    resized = jax.image.resize(
        ones_image, shape=target_shape, method=method, antialias=antialias)
    self.assertArraysEqual(resized, jnp.zeros(target_shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_shape={}_target={}_method={}".format(
              jtu.format_shape_dtype_string(image_shape, dtype),
              jtu.format_shape_dtype_string(target_shape, dtype), method),
          dtype=dtype,
          image_shape=image_shape,
          target_shape=target_shape,
          scale=scale,
          translation=translation,
          method=method)
      for dtype in inexact_dtypes
      for image_shape, target_shape, scale, translation in [
          ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]),
          ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0],
           [0.0, 1.0, -1.0, 0.0]),
      ]
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]))
  def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                              translation, method):
    # Upscales and shifts a fixed image and checks against stored values.
    pixels = [64, 32, 32, 64, 50, 100]
    # Note zeros occur in the output because the sampling location is outside
    # the boundaries of the input image.
    expected_by_method = {
        "linear": [
            0.0, 0.0, 0.0, 0.0,
            56.0, 40.0, 32.0, 0.0,
            52.0, 44.0, 40.0, 0.0,
            44.0, 52.0, 56.0, 0.0,
            45.625, 63.875, 73.0, 0.0,
            56.875, 79.625, 91.0, 0.0,
        ],
        "lanczos3": [
            0.0, 0.0, 0.0, 0.0,
            59.6281, 38.4313, 22.23, 0.0,
            52.0037, 40.6454, 31.964, 0.0,
            41.0779, 47.9383, 53.1818, 0.0,
            43.0769, 67.1244, 85.5045, 0.0,
            56.4713, 83.5243, 104.2017, 0.0,
        ],
        "lanczos5": [
            0.0, 0.0, 0.0, 0.0,
            60.0223, 40.6694, 23.1219, 0.0,
            51.2369, 39.5593, 28.9709, 0.0,
            40.8875, 46.5604, 51.7041, 0.0,
            43.5299, 67.7223, 89.658, 0.0,
            56.784, 83.984, 108.6467, 0.0,
        ],
        "cubic": [
            0.0, 0.0, 0.0, 0.0,
            59.0252, 36.9748, 25.8547, 0.0,
            53.3386, 41.4789, 35.4981, 0.0,
            41.285, 51.0051, 55.9071, 0.0,
            42.151, 65.8032, 77.731, 0.0,
            55.823, 83.9288, 98.1026, 0.0,
        ],
    }
    x = np.array(pixels, dtype=dtype).reshape(image_shape)
    # Should we test different float types here?
    scale_a = jnp.array(scale, dtype=jnp.float32)
    translation_a = jnp.array(translation, dtype=jnp.float32)
    # Every dimension of the input is treated as a spatial dimension.
    spatial_dims = range(len(image_shape))
    output = image.scale_and_translate(
        x, target_shape, spatial_dims, scale_a, translation_a, method)

    expected = np.array(
        expected_by_method[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_dtype={}_method={}_antialias={}".format(
              jtu.dtype_str(dtype), method, antialias),
          dtype=dtype,
          method=method,
          antialias=antialias)
      for dtype in inexact_dtypes
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [True, False]))
  def testScaleAndTranslateDown(self, dtype, method, antialias):
    # Downscales a fixed 6x7 image, with and without antialiasing.
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    pixels = [
        51, 38, 32, 89, 41, 21, 97,
        51, 33, 87, 89, 34, 21, 97,
        43, 25, 25, 92, 41, 11, 84,
        11, 55, 111, 23, 99, 50, 83,
        13, 92, 52, 43, 90, 43, 14,
        89, 71, 32, 23, 23, 35, 93,
    ]
    if antialias:
      expected_by_method = {
          "linear": [
              43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
          ],
          "lanczos3": [
              43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
          ],
          "lanczos5": [
              43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
          ],
          "cubic": [
              42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
          ],
      }
    else:
      expected_by_method = {
          "linear": [
              43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
          ],
          "lanczos3": [
              44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
          ],
          "lanczos5": [
              44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
          ],
          "cubic": [
              43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
          ],
      }
    x = np.array(pixels, dtype=dtype).reshape(image_shape)

    expected = np.array(
        expected_by_method[method], dtype=dtype).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    full_output = image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3),
        scale_a, translation_a, method, antialias=antialias)
    self.assertAllClose(full_output, expected, atol=2e-03)

    # Tests that running with just a subset of dimensions that have non-trivial
    # scale and translation.
    subset_output = image.scale_and_translate(
        x, target_shape, (1, 2),
        scale_a[1:3], translation_a[1:3], method, antialias=antialias)
    self.assertAllClose(subset_output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="antialias={}".format(antialias),
          antialias=antialias)
      for antialias in [True, False]))
  def testScaleAndTranslateJITs(self, antialias):
    # Runs scale_and_translate under jax.jit with traced scale/translation.
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    pixels = [
        51, 38, 32, 89, 41, 21, 97,
        51, 33, 87, 89, 34, 21, 97,
        43, 25, 25, 92, 41, 11, 84,
        11, 55, 111, 23, 99, 50, 83,
        13, 92, 52, 43, 90, 43, 14,
        89, 71, 32, 23, 23, 35, 93,
    ]
    expected_values = (
        [43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0]
        if antialias else
        [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0])
    x = jnp.array(pixels, dtype=jnp.float32).reshape(image_shape)

    expected = jnp.array(
        expected_values, dtype=jnp.float32).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def jit_fn(in_array, s, t):
      return jax.image.scale_and_translate(
          in_array, target_shape, (0, 1, 2, 3), s, t,
          "linear", antialias, precision=jax.lax.Precision.HIGHEST)

    jitted = jax.jit(jit_fn)
    output = jitted(x, scale_a, translation_a)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="antialias={}".format(antialias),
          antialias=antialias)
      for antialias in [True, False]))
  def testScaleAndTranslateGradFinite(self, antialias):
    # Gradients with respect to scale and translation must be finite.
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    pixels = [
        51, 38, 32, 89, 41, 21, 97,
        51, 33, 87, 89, 34, 21, 97,
        43, 25, 25, 92, 41, 11, 84,
        11, 55, 111, 23, 99, 50, 83,
        13, 92, 52, 43, 90, 43, 14,
        89, 71, 32, 23, 23, 35, 93,
    ]

    x = jnp.array(pixels, dtype=jnp.float32).reshape(image_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def summed_output(s, t):
      return jnp.sum(jax.image.scale_and_translate(
          x, target_shape, (0, 1, 2, 3), s, t, "linear", antialias,
          precision=jax.lax.Precision.HIGHEST))

    # Differentiate with respect to the scale, holding translation fixed.
    scale_out = jax.grad(summed_output, argnums=0)(scale_a, translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    # Differentiate with respect to the translation, holding scale fixed.
    translate_out = jax.grad(summed_output, argnums=1)(scale_a, translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))

  def testResizeWithUnusualShapes(self):
    # Covers target shapes given as arrays and as non-integer values.
    x = jnp.ones((3, 4))
    # Array shapes are accepted
    resized = jax.image.resize(x, jnp.array((10, 17)), "nearest")
    self.assertEqual((10, 17), resized.shape)
    # Fractional shapes are disallowed
    with self.assertRaises(TypeError):
      jax.image.resize(x, [10.5, 17], "bicubic")
|
|
|
|
|
|
if __name__ == "__main__":
  # When executed as a script, run the tests through absl's runner using the
  # test loader supplied by JAX's test utilities.
  absltest.main(testLoader=jtu.JaxTestLoader())
|