mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[JAX] Canonicalize shapes in jax.image.resize().
This makes jax.image.resize() robust to having jnp arrays passed as sizes. It also turns out some users were passing floating point values here, and this means they are correctly flagged as errors. PiperOrigin-RevId: 399996702
This commit is contained in:
parent
372839863d
commit
a163b6ec5d
@ -16,6 +16,7 @@ from functools import partial
|
|||||||
import enum
|
import enum
|
||||||
from typing import Callable, Sequence, Union
|
from typing import Callable, Sequence, Union
|
||||||
|
|
||||||
|
from jax import core
|
||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
@ -194,6 +195,7 @@ def scale_and_translate(image, shape: Sequence[int],
|
|||||||
Returns:
|
Returns:
|
||||||
The scale and translated image.
|
The scale and translated image.
|
||||||
"""
|
"""
|
||||||
|
shape = core.canonicalize_shape(shape)
|
||||||
if len(shape) != image.ndim:
|
if len(shape) != image.ndim:
|
||||||
msg = ('shape must have length equal to the number of dimensions of x; '
|
msg = ('shape must have length equal to the number of dimensions of x; '
|
||||||
f' {shape} vs {image.shape}')
|
f' {shape} vs {image.shape}')
|
||||||
@ -303,4 +305,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
|||||||
Returns:
|
Returns:
|
||||||
The resized image.
|
The resized image.
|
||||||
"""
|
"""
|
||||||
return _resize(image, tuple(shape), method, antialias, precision)
|
return _resize(image, core.canonicalize_shape(shape), method, antialias,
|
||||||
|
precision)
|
||||||
|
@ -382,6 +382,15 @@ class ImageTest(jtu.JaxTestCase):
|
|||||||
self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
|
self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
|
||||||
|
|
||||||
|
|
||||||
|
def testResizeWithUnusualShapes(self):
|
||||||
|
x = jnp.ones((3, 4))
|
||||||
|
# Array shapes are accepted
|
||||||
|
self.assertEqual((10, 17),
|
||||||
|
jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
# Fractional shapes are disallowed
|
||||||
|
jax.image.resize(x, [10.5, 17], "bicubic")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user