Fix incorrect implementation to find canonicalize_axis

This commit is contained in:
Du Phan 2022-02-25 21:48:39 -05:00
parent 0a859453b3
commit 0b022b71a5

View File

@ -90,9 +90,8 @@ def _scale_and_translate(x, output_shape: core.Shape,
contractions = []
in_indices = list(range(len(output_shape)))
out_indices = list(range(len(output_shape)))
spatial_ndim = len(spatial_dims)
for i, d in enumerate(spatial_dims):
d = canonicalize_axis(d, spatial_ndim)
d = canonicalize_axis(d, x.ndim)
m = input_shape[d]
n = output_shape[d]
w = compute_weight_mat(m, n, scale[i], translation[i],