From a163b6ec5d26bb488e243fe578c76a0ab71de755 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 30 Sep 2021 12:36:47 -0700 Subject: [PATCH] [JAX] Canonicalize shapes in jax.image.resize(). This makes jax.image.resize() robust to having jnp arrays passed as sizes. It also turns out some users were passing floating point values here, and this means they are correctly flagged as errors. PiperOrigin-RevId: 399996702 --- jax/_src/image/scale.py | 5 ++++- tests/image_test.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 9a6fd4f89..fa86e4b1c 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -16,6 +16,7 @@ from functools import partial import enum from typing import Callable, Sequence, Union +from jax import core from jax import jit from jax import lax from jax import numpy as jnp @@ -194,6 +195,7 @@ def scale_and_translate(image, shape: Sequence[int], Returns: The scale and translated image. """ + shape = core.canonicalize_shape(shape) if len(shape) != image.ndim: msg = ('shape must have length equal to the number of dimensions of x; ' f' {shape} vs {image.shape}') @@ -303,4 +305,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod], Returns: The resized image. """ - return _resize(image, tuple(shape), method, antialias, precision) + return _resize(image, core.canonicalize_shape(shape), method, antialias, + precision) diff --git a/tests/image_test.py b/tests/image_test.py index 4da90a04e..cafd7944a 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -382,6 +382,15 @@ class ImageTest(jtu.JaxTestCase): self.assertTrue(jnp.all(jnp.isfinite(translate_out))) + def testResizeWithUnusualShapes(self): + x = jnp.ones((3, 4)) + # Array shapes are accepted + self.assertEqual((10, 17), + jax.image.resize(x, jnp.array((10, 17)), "nearest").shape) + with self.assertRaises(TypeError): + # Fractional shapes are disallowed + jax.image.resize(x, [10.5, 17], "bicubic") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())