Cleanup: remove duplicate canonicalize_axis utility

This commit is contained in:
Jake VanderPlas 2021-11-23 16:54:02 -08:00
parent 28b3c46b9b
commit f6e3f1b4ad
5 changed files with 9 additions and 24 deletions

View File

@ -2616,7 +2616,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
return c_flat
# Check that all inputs have a consistent leading dimension `num_elems`.
axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
raise ValueError('Array inputs to associative_scan must have the same '

View File

@ -5847,15 +5847,3 @@ def _check_user_dtype_supported(dtype, fun_name=None):
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=2)
def _canonicalize_axis(axis, num_dims):
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
if axis < 0:
axis = axis + num_dims
return axis

View File

@ -28,14 +28,13 @@ from jax.config import config
from jax.core import NamedShape
from jax._src.api import jit, vmap
from jax._src.numpy.lax_numpy import (_arraylike, _check_arraylike,
_constant_like, _convert_and_clip_integer,
_canonicalize_axis)
_constant_like, _convert_and_clip_integer)
from jax._src.lib import xla_bridge
from jax.numpy.linalg import cholesky, svd, eigh
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src.util import prod
from jax._src.util import prod, canonicalize_axis
Array = Any
@ -390,7 +389,7 @@ def permutation(key: KeyArray,
"""
key, _ = _check_prng_key(key)
_check_arraylike("permutation", x)
axis = _canonicalize_axis(axis, np.ndim(x) or 1)
axis = canonicalize_axis(axis, np.ndim(x) or 1)
if not np.ndim(x):
if not np.issubdtype(lax.dtype(x), np.integer):
raise TypeError("x must be an integer or at least 1-dimensional")
@ -466,7 +465,7 @@ def choice(key: KeyArray,
a = core.concrete_or_error(int, a, "The error occurred in jax.random.choice()")
else:
a = jnp.asarray(a)
axis = _canonicalize_axis(axis, np.ndim(a) or 1)
axis = canonicalize_axis(axis, np.ndim(a) or 1)
n_inputs = int(a) if np.ndim(a) == 0 else a.shape[axis] # type: ignore[arg-type]
n_draws = prod(shape)
if n_draws == 0:

View File

@ -16,7 +16,7 @@ from functools import partial
import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used
from jax import lax, numpy as jnp
from jax._src.lax.lax import _canonicalize_axis
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _wraps
def _W4(N, k):
@ -40,7 +40,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None):
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
axis = _canonicalize_axis(axis, x.ndim)
axis = canonicalize_axis(axis, x.ndim)
if n is not None:
x = lax.pad(x, jnp.array(0, x.dtype),
[(0, n - x.shape[axis] if a == axis else 0, 0)
@ -58,7 +58,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None):
def _dct2(x, axes, norm):
axis1, axis2 = map(partial(_canonicalize_axis, num_dims=x.ndim), axes)
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
N1, N2 = x.shape[axis1], x.shape[axis2]
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
V = jnp.fft.fftn(v, axes=axes)

View File

@ -286,9 +286,7 @@ def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
if axis < 0:
axis = axis + num_dims
return axis