[JAX] Canonicalize shapes in jax.image.resize().

This makes jax.image.resize() robust to having jnp arrays passed as sizes. It also turns out some users were passing floating point values here, and this means they are correctly flagged as errors.

PiperOrigin-RevId: 399996702
This commit is contained in:
Peter Hawkins 2021-09-30 12:36:47 -07:00 committed by jax authors
parent 372839863d
commit a163b6ec5d
2 changed files with 13 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from functools import partial
import enum import enum
from typing import Callable, Sequence, Union from typing import Callable, Sequence, Union
from jax import core
from jax import jit from jax import jit
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
@ -194,6 +195,7 @@ def scale_and_translate(image, shape: Sequence[int],
Returns: Returns:
The scale and translated image. The scale and translated image.
""" """
shape = core.canonicalize_shape(shape)
if len(shape) != image.ndim: if len(shape) != image.ndim:
msg = ('shape must have length equal to the number of dimensions of x; ' msg = ('shape must have length equal to the number of dimensions of x; '
f' {shape} vs {image.shape}') f' {shape} vs {image.shape}')
@ -303,4 +305,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
Returns: Returns:
The resized image. The resized image.
""" """
return _resize(image, tuple(shape), method, antialias, precision) return _resize(image, core.canonicalize_shape(shape), method, antialias,
precision)

View File

@ -382,6 +382,15 @@ class ImageTest(jtu.JaxTestCase):
self.assertTrue(jnp.all(jnp.isfinite(translate_out))) self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
def testResizeWithUnusualShapes(self):
x = jnp.ones((3, 4))
# Array shapes are accepted
self.assertEqual((10, 17),
jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
with self.assertRaises(TypeError):
# Fractional shapes are disallowed
jax.image.resize(x, [10.5, 17], "bicubic")
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())