mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jnp.repeat: add copy argument for Array API
This commit is contained in:
parent
8bcd288621
commit
5198db9fdb
@ -1289,7 +1289,8 @@ def isrealobj(x: Any) -> bool:
|
||||
|
||||
def reshape(
|
||||
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
|
||||
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
|
||||
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
|
||||
copy: bool | None = None) -> Array:
|
||||
"""Return a reshaped copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.reshape`, implemented in terms of
|
||||
@ -1303,6 +1304,8 @@ def reshape(
|
||||
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
|
||||
(fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``.
|
||||
JAX does not support ``order="A"``.
|
||||
copy: unused by JAX; JAX always returns a copy, though under JIT the compiler
|
||||
may optimize such copies away.
|
||||
|
||||
Returns:
|
||||
reshaped copy of input array with the specified shape.
|
||||
@ -1355,6 +1358,8 @@ def reshape(
|
||||
[3, 4],
|
||||
[5, 6]], dtype=int32)
|
||||
"""
|
||||
del copy # unused
|
||||
|
||||
__tracebackhide__ = True
|
||||
util.check_arraylike("reshape", a)
|
||||
|
||||
|
@ -148,6 +148,7 @@ from jax.numpy import (
|
||||
real as real,
|
||||
remainder as remainder,
|
||||
repeat as repeat,
|
||||
reshape as reshape,
|
||||
result_type as result_type,
|
||||
roll as roll,
|
||||
round as round,
|
||||
@ -188,10 +189,6 @@ from jax.numpy import (
|
||||
zeros_like as zeros_like,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._manipulation_functions import (
|
||||
reshape as reshape,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._data_type_functions import (
|
||||
astype as astype,
|
||||
)
|
||||
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import jax
|
||||
from jax import Array
|
||||
|
||||
|
||||
# TODO(micky774): Implement copy
|
||||
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
|
||||
"""Reshapes an array without changing its data."""
|
||||
del copy # unused
|
||||
return jax.numpy.reshape(x, shape)
|
Loading…
x
Reference in New Issue
Block a user