mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Updating image_test
This commit is contained in:
parent
be50847cee
commit
ae910cdd31
@ -21,6 +21,7 @@ import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import image
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
@ -181,7 +182,7 @@ class ImageTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":"_shape={}_target={}_method={}".format(
|
||||
{"testcase_name": "_shape={}_target={}_method={}".format(
|
||||
jtu.format_shape_dtype_string(image_shape, dtype),
|
||||
jtu.format_shape_dtype_string(target_shape, dtype), method),
|
||||
"dtype": dtype, "image_shape": image_shape,
|
||||
@ -221,7 +222,8 @@ class ImageTest(jtu.JaxTestCase):
|
||||
# Should we test different float types here?
|
||||
scale_a = jnp.array(scale, dtype=jnp.float32)
|
||||
translation_a = jnp.array(translation, dtype=jnp.float32)
|
||||
output = image.scale_and_translate(x, target_shape, scale_a, translation_a,
|
||||
output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
|
||||
scale_a, translation_a,
|
||||
method)
|
||||
|
||||
expected = np.array(
|
||||
@ -231,7 +233,7 @@ class ImageTest(jtu.JaxTestCase):
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dtype={}_method={}_antialias={}".format(
|
||||
jtu.dtype_str(dtype), method, antialias),
|
||||
"dtype":dtype, "method":method, "antialias": antialias}
|
||||
"dtype": dtype, "method": method, "antialias": antialias}
|
||||
for dtype in inexact_dtypes
|
||||
for method in ["linear", "lanczos3", "lanczos5", "cubic"]
|
||||
for antialias in [True, False]))
|
||||
@ -274,12 +276,55 @@ class ImageTest(jtu.JaxTestCase):
|
||||
]
|
||||
x = np.array(data, dtype=dtype).reshape(image_shape)
|
||||
|
||||
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
|
||||
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
|
||||
output = image.scale_and_translate(
|
||||
x, target_shape, scale_a, translation_a, method, antialias=antialias)
|
||||
expected = np.array(
|
||||
expected_data[method], dtype=dtype).reshape(target_shape)
|
||||
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
|
||||
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
|
||||
|
||||
output = image.scale_and_translate(
|
||||
x, target_shape, (0,1,2,3),
|
||||
scale_a, translation_a, method, antialias=antialias)
|
||||
self.assertAllClose(output, expected, atol=2e-03)
|
||||
|
||||
# Tests that running with just a subset of dimensions that have non-trivial
|
||||
# scale and translation.
|
||||
output = image.scale_and_translate(
|
||||
x, target_shape, (1,2),
|
||||
scale_a[1:3], translation_a[1:3], method, antialias=antialias)
|
||||
self.assertAllClose(output, expected, atol=2e-03)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "antialias={}".format(antialias),
|
||||
"antialias": antialias}
|
||||
for antialias in [True, False]))
|
||||
def testScaleAndTranslateJITs(self, antialias):
|
||||
image_shape = [1, 6, 7, 1]
|
||||
target_shape = [1, 3, 3, 1]
|
||||
|
||||
data = [
|
||||
51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
|
||||
41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
|
||||
71, 32, 23, 23, 35, 93
|
||||
]
|
||||
if antialias:
|
||||
expected_data = [
|
||||
43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
|
||||
]
|
||||
else:
|
||||
expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0]
|
||||
x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
|
||||
|
||||
expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape)
|
||||
scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
|
||||
translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)
|
||||
|
||||
def jit_fn(in_array, s, t):
|
||||
return jax.image.scale_and_translate(
|
||||
in_array, target_shape, (0, 1, 2, 3), s, t,
|
||||
"linear", antialias, precision=jax.lax.Precision.HIGHEST)
|
||||
|
||||
output = jax.jit(jit_fn)(x, scale_a, translation_a)
|
||||
self.assertAllClose(output, expected, atol=2e-03)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user