Use canonicalize_axis

This commit is contained in:
Du Phan 2022-02-25 17:39:47 -05:00
parent e28ec78d7a
commit 0a859453b3

View File

@ -20,6 +20,7 @@ from jax import core
from jax import jit
from jax import lax
from jax import numpy as jnp
from jax._src.util import canonicalize_axis
import numpy as np
@ -91,8 +92,7 @@ def _scale_and_translate(x, output_shape: core.Shape,
out_indices = list(range(len(output_shape)))
spatial_ndim = len(spatial_dims)
for i, d in enumerate(spatial_dims):
if d < 0:
d = spatial_ndim + d
d = canonicalize_axis(d, spatial_ndim)
m = input_shape[d]
n = output_shape[d]
w = compute_weight_mat(m, n, scale[i], translation[i],