Use canonicalize_axis

This commit is contained in:
Du Phan 2022-02-25 17:39:47 -05:00
parent e28ec78d7a
commit 0a859453b3

View File

@ -20,6 +20,7 @@ from jax import core
from jax import jit from jax import jit
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
from jax._src.util import canonicalize_axis
import numpy as np import numpy as np
@ -91,8 +92,7 @@ def _scale_and_translate(x, output_shape: core.Shape,
out_indices = list(range(len(output_shape))) out_indices = list(range(len(output_shape)))
spatial_ndim = len(spatial_dims) spatial_ndim = len(spatial_dims)
for i, d in enumerate(spatial_dims): for i, d in enumerate(spatial_dims):
if d < 0: d = canonicalize_axis(d, spatial_ndim)
d = spatial_ndim + d
m = input_shape[d] m = input_shape[d]
n = output_shape[d] n = output_shape[d]
w = compute_weight_mat(m, n, scale[i], translation[i], w = compute_weight_mat(m, n, scale[i], translation[i],