mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jit to jax.image.resize (#3714)
* Add image/ directory to Bazel build. * Use a jit on jax.image.resize to reduce compilation time. Relax bfloat16 test tolerance.
This commit is contained in:
parent
b943b31b22
commit
417de0d351
@ -27,6 +27,7 @@ pytype_library(
|
||||
srcs = glob(
|
||||
[
|
||||
"*.py",
|
||||
"image/**/*.py",
|
||||
"lib/**/*.py",
|
||||
"interpreters/**/*.py",
|
||||
"lax/**/*.py",
|
||||
|
@ -12,10 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import enum
|
||||
import math
|
||||
from typing import Callable, Sequence, Tuple, Union
|
||||
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
@ -146,6 +148,21 @@ _kernels[ResizeMethod.LANCZOS5] = _lanczos_kernel(5.)
|
||||
_kernels[ResizeMethod.CUBIC] = _keys_cubic_kernel()
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3))
|
||||
def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
antialias: bool):
|
||||
if len(shape) != image.ndim:
|
||||
msg = ('shape must have length equal to the number of dimensions of x; '
|
||||
f' {shape} vs {image.shape}')
|
||||
raise ValueError(msg)
|
||||
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
|
||||
else method]
|
||||
scale = [float(o) / i for o, i in zip(shape, image.shape)]
|
||||
if not jnp.issubdtype(image.dtype, jnp.inexact):
|
||||
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
|
||||
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
|
||||
antialias)
|
||||
|
||||
def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
antialias: bool = True):
|
||||
"""Image resize.
|
||||
@ -183,14 +200,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
Returns:
|
||||
The resized image.
|
||||
"""
|
||||
if len(shape) != image.ndim:
|
||||
msg = ('shape must have length equal to the number of dimensions of x; '
|
||||
f' {shape} vs {image.shape}')
|
||||
raise ValueError(msg)
|
||||
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
|
||||
else method]
|
||||
scale = [float(o) / i for o, i in zip(shape, image.shape)]
|
||||
if not jnp.issubdtype(image.dtype, jnp.inexact):
|
||||
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
|
||||
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
|
||||
antialias)
|
||||
return _resize(image, shape, method, antialias)
|
||||
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import image
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
@ -79,8 +80,8 @@ class ImageTest(jtu.JaxTestCase):
|
||||
jax_fn = partial(image.resize, shape=target_shape, method=method,
|
||||
antialias=antialias)
|
||||
self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
|
||||
tol={np.float16: 2e-2, np.float32: 1e-4,
|
||||
np.float64: 1e-4})
|
||||
tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
|
||||
np.float32: 1e-4, np.float64: 1e-4})
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user