mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jit to jax.image.resize (#3714)
* Add image/ directory to Bazel build. * Use a jit on jax.image.resize to reduce compilation time. Relax bfloat16 test tolerance.
This commit is contained in:
parent
b943b31b22
commit
417de0d351
@ -27,6 +27,7 @@ pytype_library(
|
|||||||
srcs = glob(
|
srcs = glob(
|
||||||
[
|
[
|
||||||
"*.py",
|
"*.py",
|
||||||
|
"image/**/*.py",
|
||||||
"lib/**/*.py",
|
"lib/**/*.py",
|
||||||
"interpreters/**/*.py",
|
"interpreters/**/*.py",
|
||||||
"lax/**/*.py",
|
"lax/**/*.py",
|
||||||
|
@ -12,10 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
import enum
|
import enum
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Sequence, Tuple, Union
|
from typing import Callable, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
from jax import jit
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -146,6 +148,21 @@ _kernels[ResizeMethod.LANCZOS5] = _lanczos_kernel(5.)
|
|||||||
_kernels[ResizeMethod.CUBIC] = _keys_cubic_kernel()
|
_kernels[ResizeMethod.CUBIC] = _keys_cubic_kernel()
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnums=(1, 2, 3))
|
||||||
|
def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||||
|
antialias: bool):
|
||||||
|
if len(shape) != image.ndim:
|
||||||
|
msg = ('shape must have length equal to the number of dimensions of x; '
|
||||||
|
f' {shape} vs {image.shape}')
|
||||||
|
raise ValueError(msg)
|
||||||
|
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
|
||||||
|
else method]
|
||||||
|
scale = [float(o) / i for o, i in zip(shape, image.shape)]
|
||||||
|
if not jnp.issubdtype(image.dtype, jnp.inexact):
|
||||||
|
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
|
||||||
|
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
|
||||||
|
antialias)
|
||||||
|
|
||||||
def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||||
antialias: bool = True):
|
antialias: bool = True):
|
||||||
"""Image resize.
|
"""Image resize.
|
||||||
@ -183,14 +200,5 @@ def resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
|||||||
Returns:
|
Returns:
|
||||||
The resized image.
|
The resized image.
|
||||||
"""
|
"""
|
||||||
if len(shape) != image.ndim:
|
return _resize(image, shape, method, antialias)
|
||||||
msg = ('shape must have length equal to the number of dimensions of x; '
|
|
||||||
f' {shape} vs {image.shape}')
|
|
||||||
raise ValueError(msg)
|
|
||||||
kernel = _kernels[ResizeMethod.from_string(method) if isinstance(method, str)
|
|
||||||
else method]
|
|
||||||
scale = [float(o) / i for o, i in zip(shape, image.shape)]
|
|
||||||
if not jnp.issubdtype(image.dtype, jnp.inexact):
|
|
||||||
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
|
|
||||||
return _scale_and_translate(image, shape, scale, [0.] * image.ndim, kernel,
|
|
||||||
antialias)
|
|
||||||
|
@ -22,6 +22,7 @@ from absl.testing import absltest
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from jax import image
|
from jax import image
|
||||||
|
from jax import numpy as jnp
|
||||||
from jax import test_util as jtu
|
from jax import test_util as jtu
|
||||||
|
|
||||||
from jax.config import config
|
from jax.config import config
|
||||||
@ -79,8 +80,8 @@ class ImageTest(jtu.JaxTestCase):
|
|||||||
jax_fn = partial(image.resize, shape=target_shape, method=method,
|
jax_fn = partial(image.resize, shape=target_shape, method=method,
|
||||||
antialias=antialias)
|
antialias=antialias)
|
||||||
self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
|
self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
|
||||||
tol={np.float16: 2e-2, np.float32: 1e-4,
|
tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
|
||||||
np.float64: 1e-4})
|
np.float32: 1e-4, np.float64: 1e-4})
|
||||||
|
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user