Add jit to jax.image.resize (#3714)

* Add image/ directory to Bazel build.

* Use a jit on jax.image.resize to reduce compilation time.

Relax bfloat16 test tolerance.
This commit is contained in:
Peter Hawkins 2020-07-10 10:32:13 -04:00 committed by GitHub
parent b943b31b22
commit 417de0d351
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 13 deletions

View File

@ -27,6 +27,7 @@ pytype_library(
srcs = glob( srcs = glob(
[ [
"*.py", "*.py",
"image/**/*.py",
"lib/**/*.py", "lib/**/*.py",
"interpreters/**/*.py", "interpreters/**/*.py",
"lax/**/*.py", "lax/**/*.py",

View File

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import enum import enum
import math import math
from typing import Callable, Sequence, Tuple, Union from typing import Callable, Sequence, Tuple, Union
from jax import jit
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
import numpy as np import numpy as np
@ -146,6 +148,21 @@ _kernels[ResizeMethod.LANCZOS5] = _lanczos_kernel(5.)
_kernels[ResizeMethod.CUBIC] = _keys_cubic_kernel() _kernels[ResizeMethod.CUBIC] = _keys_cubic_kernel()
@partial(jit, static_argnums=(1, 2, 3))
def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
antialias: bool):
if len(shape) != image.ndim:
msg = ('shape must have length equal to the number of dimensions of x; '
f' {shape} vs {image.shape}')
raise ValueError(msg)
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
else method]
scale = [float(o) / i for o, i in zip(shape, image.shape)]
if not jnp.issubdtype(image.dtype, jnp.inexact):
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
antialias)
def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod], def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
antialias: bool = True): antialias: bool = True):
"""Image resize. """Image resize.
@ -183,14 +200,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
Returns: Returns:
The resized image. The resized image.
""" """
if len(shape) != image.ndim: return _resize(image, shape, method, antialias)
msg = ('shape must have length equal to the number of dimensions of x; '
f' {shape} vs {image.shape}')
raise ValueError(msg)
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
else method]
scale = [float(o) / i for o, i in zip(shape, image.shape)]
if not jnp.issubdtype(image.dtype, jnp.inexact):
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
antialias)

View File

@ -22,6 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
from jax import image from jax import image
from jax import numpy as jnp
from jax import test_util as jtu from jax import test_util as jtu
from jax.config import config from jax.config import config
@ -79,8 +80,8 @@ class ImageTest(jtu.JaxTestCase):
jax_fn = partial(image.resize, shape=target_shape, method=method, jax_fn = partial(image.resize, shape=target_shape, method=method,
antialias=antialias) antialias=antialias)
self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True, self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
tol={np.float16: 2e-2, np.float32: 1e-4, tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
np.float64: 1e-4}) np.float32: 1e-4, np.float64: 1e-4})
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(