mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Cleanup: remove duplicate canonicalize_axis utility
This commit is contained in:
parent
28b3c46b9b
commit
f6e3f1b4ad
@ -2616,7 +2616,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
return c_flat
|
||||
|
||||
# Check that all inputs have a consistent leading dimension `num_elems`.
|
||||
axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)
|
||||
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
|
||||
num_elems = int(elems_flat[0].shape[axis])
|
||||
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
|
||||
raise ValueError('Array inputs to associative_scan must have the same '
|
||||
|
@ -5847,15 +5847,3 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
fun_name = f"requested in {fun_name}" if fun_name else ""
|
||||
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=2)
|
||||
|
||||
|
||||
def _canonicalize_axis(axis, num_dims):
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
if not -num_dims <= axis < num_dims:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of dimension {}".format(
|
||||
axis, num_dims))
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
return axis
|
||||
|
@ -28,14 +28,13 @@ from jax.config import config
|
||||
from jax.core import NamedShape
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.numpy.lax_numpy import (_arraylike, _check_arraylike,
|
||||
_constant_like, _convert_and_clip_integer,
|
||||
_canonicalize_axis)
|
||||
_constant_like, _convert_and_clip_integer)
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import xla
|
||||
from jax._src.util import prod
|
||||
from jax._src.util import prod, canonicalize_axis
|
||||
|
||||
|
||||
Array = Any
|
||||
@ -390,7 +389,7 @@ def permutation(key: KeyArray,
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
_check_arraylike("permutation", x)
|
||||
axis = _canonicalize_axis(axis, np.ndim(x) or 1)
|
||||
axis = canonicalize_axis(axis, np.ndim(x) or 1)
|
||||
if not np.ndim(x):
|
||||
if not np.issubdtype(lax.dtype(x), np.integer):
|
||||
raise TypeError("x must be an integer or at least 1-dimensional")
|
||||
@ -466,7 +465,7 @@ def choice(key: KeyArray,
|
||||
a = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
|
||||
else:
|
||||
a = jnp.asarray(a)
|
||||
axis = _canonicalize_axis(axis, np.ndim(a) or 1)
|
||||
axis = canonicalize_axis(axis, np.ndim(a) or 1)
|
||||
n_inputs = int(a) if np.ndim(a) == 0 else a.shape[axis] # type: ignore[arg-type]
|
||||
n_draws = prod(shape)
|
||||
if n_draws == 0:
|
||||
|
@ -16,7 +16,7 @@ from functools import partial
|
||||
|
||||
import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used
|
||||
from jax import lax, numpy as jnp
|
||||
from jax._src.lax.lax import _canonicalize_axis
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.numpy.util import _wraps
|
||||
|
||||
def _W4(N, k):
|
||||
@ -40,7 +40,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None):
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
||||
axis = _canonicalize_axis(axis, x.ndim)
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
if n is not None:
|
||||
x = lax.pad(x, jnp.array(0, x.dtype),
|
||||
[(0, n - x.shape[axis] if a == axis else 0, 0)
|
||||
@ -58,7 +58,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None):
|
||||
|
||||
|
||||
def _dct2(x, axes, norm):
|
||||
axis1, axis2 = map(partial(_canonicalize_axis, num_dims=x.ndim), axes)
|
||||
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
|
||||
N1, N2 = x.shape[axis1], x.shape[axis2]
|
||||
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
|
||||
V = jnp.fft.fftn(v, axes=axes)
|
||||
|
@ -286,9 +286,7 @@ def canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
if not -num_dims <= axis < num_dims:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of dimension {}".format(
|
||||
axis, num_dims))
|
||||
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
return axis
|
||||
|
Loading…
x
Reference in New Issue
Block a user