jnp.repeat: add copy argument for Array API

This commit is contained in:
Jake VanderPlas 2024-07-30 14:07:08 -07:00
parent 8bcd288621
commit 5198db9fdb
3 changed files with 7 additions and 30 deletions

View File

@ -1289,7 +1289,8 @@ def isrealobj(x: Any) -> bool:
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
copy: bool | None = None) -> Array:
"""Return a reshaped copy of an array.
JAX implementation of :func:`numpy.reshape`, implemented in terms of
@ -1303,6 +1304,8 @@ def reshape(
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
(fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``.
JAX does not support ``order="A"``.
copy: unused by JAX; JAX always returns a copy, though under JIT the compiler
may optimize such copies away.
Returns:
reshaped copy of input array with the specified shape.
@ -1355,6 +1358,8 @@ def reshape(
[3, 4],
[5, 6]], dtype=int32)
"""
del copy # unused
__tracebackhide__ = True
util.check_arraylike("reshape", a)

View File

@ -148,6 +148,7 @@ from jax.numpy import (
real as real,
remainder as remainder,
repeat as repeat,
reshape as reshape,
result_type as result_type,
roll as roll,
round as round,
@ -188,10 +189,6 @@ from jax.numpy import (
zeros_like as zeros_like,
)
from jax.experimental.array_api._manipulation_functions import (
reshape as reshape,
)
from jax.experimental.array_api._data_type_functions import (
astype as astype,
)

View File

@ -1,25 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import jax
from jax import Array
# TODO(micky774): Implement copy
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
return jax.numpy.reshape(x, shape)