mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use canonicalize_axis
This commit is contained in:
parent
e28ec78d7a
commit
0a859453b3
@ -20,6 +20,7 @@ from jax import core
|
|||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
from jax._src.util import canonicalize_axis
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@ -91,8 +92,7 @@ def _scale_and_translate(x, output_shape: core.Shape,
|
|||||||
out_indices = list(range(len(output_shape)))
|
out_indices = list(range(len(output_shape)))
|
||||||
spatial_ndim = len(spatial_dims)
|
spatial_ndim = len(spatial_dims)
|
||||||
for i, d in enumerate(spatial_dims):
|
for i, d in enumerate(spatial_dims):
|
||||||
if d < 0:
|
d = canonicalize_axis(d, spatial_ndim)
|
||||||
d = spatial_ndim + d
|
|
||||||
m = input_shape[d]
|
m = input_shape[d]
|
||||||
n = output_shape[d]
|
n = output_shape[d]
|
||||||
w = compute_weight_mat(m, n, scale[i], translation[i],
|
w = compute_weight_mat(m, n, scale[i], translation[i],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user