mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Use canonicalize_axis
This commit is contained in:
parent
e28ec78d7a
commit
0a859453b3
@ -20,6 +20,7 @@ from jax import core
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -91,8 +92,7 @@ def _scale_and_translate(x, output_shape: core.Shape,
|
||||
out_indices = list(range(len(output_shape)))
|
||||
spatial_ndim = len(spatial_dims)
|
||||
for i, d in enumerate(spatial_dims):
|
||||
if d < 0:
|
||||
d = spatial_ndim + d
|
||||
d = canonicalize_axis(d, spatial_ndim)
|
||||
m = input_shape[d]
|
||||
n = output_shape[d]
|
||||
w = compute_weight_mat(m, n, scale[i], translation[i],
|
||||
|
Loading…
x
Reference in New Issue
Block a user