mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix incorrect implementation to find canonicalize_axis
This commit is contained in:
parent
0a859453b3
commit
0b022b71a5
@ -90,9 +90,8 @@ def _scale_and_translate(x, output_shape: core.Shape,
|
||||
contractions = []
|
||||
in_indices = list(range(len(output_shape)))
|
||||
out_indices = list(range(len(output_shape)))
|
||||
spatial_ndim = len(spatial_dims)
|
||||
for i, d in enumerate(spatial_dims):
|
||||
d = canonicalize_axis(d, spatial_ndim)
|
||||
d = canonicalize_axis(d, x.ndim)
|
||||
m = input_shape[d]
|
||||
n = output_shape[d]
|
||||
w = compute_weight_mat(m, n, scale[i], translation[i],
|
||||
|
Loading…
x
Reference in New Issue
Block a user