mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
This commit is contained in:
commit
856915f3c4
@ -310,6 +310,7 @@ namespace; they are listed below.
|
||||
pad
|
||||
partition
|
||||
percentile
|
||||
permute_dims
|
||||
piecewise
|
||||
place
|
||||
poly
|
||||
|
@ -542,6 +542,12 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
|
||||
return lax.transpose(a, axes_)
|
||||
|
||||
|
||||
@util._wraps(getattr(np, "permute_dims", None))
|
||||
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array:
|
||||
util.check_arraylike("permute_dims", x)
|
||||
return lax.transpose(x, axes)
|
||||
|
||||
|
||||
@util._wraps(getattr(np, 'matrix_transpose', None))
|
||||
def matrix_transpose(x: ArrayLike, /) -> Array:
|
||||
"""Transposes the last two dimensions of x.
|
||||
|
@ -55,7 +55,7 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
|
||||
|
||||
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
|
||||
"""Permutes the axes (dimensions) of an array x."""
|
||||
return jax.lax.transpose(x, axes)
|
||||
return jax.numpy.permute_dims(x, axes=axes)
|
||||
|
||||
|
||||
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
|
||||
|
@ -189,6 +189,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
packbits as packbits,
|
||||
pad as pad,
|
||||
partition as partition,
|
||||
permute_dims as permute_dims,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
place as place,
|
||||
|
@ -611,6 +611,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
|
||||
axis: Optional[Union[int, tuple[int, ...]]] = ...,
|
||||
out: None = ..., overwrite_input: bool = ..., method: str = ...,
|
||||
keepdims: bool = ..., interpolation: None = ...) -> Array: ...
|
||||
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ...
|
||||
pi: float
|
||||
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
|
||||
funclist: Sequence[Union[ArrayLike, Callable[..., Array]]],
|
||||
|
@ -1240,6 +1240,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=array_shapes,
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testPermuteDims(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
axes = self.rng().permutation(len(shape))
|
||||
np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes)
|
||||
jnp_fun = partial(jnp.permute_dims, axes=axes)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[s for s in array_shapes if len(s) >= 2],
|
||||
dtype=default_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user