Merge pull request #19244 from jakevdp:permute-dims

PiperOrigin-RevId: 596741266
This commit is contained in:
jax authors 2024-01-08 17:02:46 -08:00
commit 856915f3c4
6 changed files with 23 additions and 1 deletions

View File

@ -310,6 +310,7 @@ namespace; they are listed below.
pad
partition
percentile
permute_dims
piecewise
place
poly

View File

@ -542,6 +542,12 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
return lax.transpose(a, axes_)
@util._wraps(getattr(np, "permute_dims", None))
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array:
util.check_arraylike("permute_dims", x)
return lax.transpose(x, axes)
@util._wraps(getattr(np, 'matrix_transpose', None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes the last two dimensions of x.

View File

@ -55,7 +55,7 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.lax.transpose(x, axes)
return jax.numpy.permute_dims(x, axes=axes)
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:

View File

@ -189,6 +189,7 @@ from jax._src.numpy.lax_numpy import (
packbits as packbits,
pad as pad,
partition as partition,
permute_dims as permute_dims,
pi as pi,
piecewise as piecewise,
place as place,

View File

@ -611,6 +611,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = ...,
out: None = ..., overwrite_input: bool = ..., method: str = ...,
keepdims: bool = ..., interpolation: None = ...) -> Array: ...
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ...
pi: float
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
funclist: Sequence[Union[ArrayLike, Callable[..., Array]]],

View File

@ -1240,6 +1240,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
shape=array_shapes,
dtype=default_dtypes,
)
def testPermuteDims(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
axes = self.rng().permutation(len(shape))
np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes)
jnp_fun = partial(jnp.permute_dims, axes=axes)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@jtu.sample_product(
shape=[s for s in array_shapes if len(s) >= 2],
dtype=default_dtypes,