From 0a859453b3a6ba4b71f33cc1eeb15bdd5e00d392 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 25 Feb 2022 17:39:47 -0500 Subject: [PATCH] Use canonicalize_axis --- jax/_src/image/scale.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 7ebbb7fa3..87823ffa3 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -20,6 +20,7 @@ from jax import core from jax import jit from jax import lax from jax import numpy as jnp +from jax._src.util import canonicalize_axis import numpy as np @@ -91,8 +92,7 @@ def _scale_and_translate(x, output_shape: core.Shape, out_indices = list(range(len(output_shape))) spatial_ndim = len(spatial_dims) for i, d in enumerate(spatial_dims): - if d < 0: - d = spatial_ndim + d + d = canonicalize_axis(d, spatial_ndim) m = input_shape[d] n = output_shape[d] w = compute_weight_mat(m, n, scale[i], translation[i],