mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Update jnp.clip
to Array API 2023 standard
This commit is contained in:
parent
033992867f
commit
8b7aae586b
@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
|
||||
lowering pass via Triton Python APIs has been removed and the
|
||||
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
|
||||
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
|
||||
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
|
||||
`max` ({jax-issue}`20550`).
|
||||
|
||||
|
||||
## jaxlib 0.4.27
|
||||
|
@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:
|
||||
|
||||
|
||||
def _clip(number: ArrayLike,
|
||||
min: ArrayLike | None = None, max: ArrayLike | None = None,
|
||||
out: None = None) -> Array:
|
||||
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
|
||||
"""Return an array whose values are limited to a specified range.
|
||||
|
||||
Refer to :func:`jax.numpy.clip` for full documentation."""
|
||||
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
|
||||
return lax_numpy.clip(number, min=min, max=max)
|
||||
|
||||
|
||||
def _transpose(a: Array, *args: Any) -> Array:
|
||||
|
@ -66,7 +66,10 @@ from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike, DimSize, DuckTypedArray,
|
||||
DType, DTypeLike, Shape, DeprecatedArg
|
||||
)
|
||||
from jax._src.util import (unzip2, subvals, safe_zip,
|
||||
ceil_of_ratio, partition_list,
|
||||
canonicalize_axis as _canonicalize_axis,
|
||||
@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
|
||||
axis: int = 0) -> list[Array]:
|
||||
return _split("array_split", ary, indices_or_sections, axis=axis)
|
||||
|
||||
@util.implements(np.clip, skip_params=['out'])
|
||||
|
||||
_DEPRECATED_CLIP_ARG = DeprecatedArg()
|
||||
@util.implements(
|
||||
np.clip,
|
||||
skip_params=['a', 'a_min'],
|
||||
extra_params=_dedent("""
|
||||
x : array_like
|
||||
Array containing elements to clip.
|
||||
min : array_like, optional
|
||||
Minimum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``min`` is broadcast against x.
|
||||
max : array_like, optional
|
||||
Maximum value. If ``None``, clipping is not performed on the
|
||||
corresponding edge. The value of ``max`` is broadcast against x.
|
||||
""")
|
||||
)
|
||||
@jit
|
||||
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
|
||||
a_max: ArrayLike | None = None, out: None = None) -> Array:
|
||||
util.check_arraylike("clip", a)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
|
||||
if a_min is None and a_max is None:
|
||||
raise ValueError("At most one of a_min and a_max may be None")
|
||||
if a_min is not None:
|
||||
a = ufuncs.maximum(a_min, a)
|
||||
if a_max is not None:
|
||||
a = ufuncs.minimum(a_max, a)
|
||||
return asarray(a)
|
||||
def clip(
|
||||
x: ArrayLike | None = None, # Default to preserve backwards compatability
|
||||
/,
|
||||
min: ArrayLike | None = None,
|
||||
max: ArrayLike | None = None,
|
||||
*,
|
||||
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
|
||||
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
|
||||
) -> Array:
|
||||
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
|
||||
x = a if not isinstance(a, DeprecatedArg) else x
|
||||
if x is None:
|
||||
raise ValueError("No input was provided to the clip function.")
|
||||
min = a_min if not isinstance(a_min, DeprecatedArg) else min
|
||||
max = a_max if not isinstance(a_max, DeprecatedArg) else max
|
||||
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
|
||||
warnings.warn(
|
||||
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
|
||||
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
util.check_arraylike("clip", x)
|
||||
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
|
||||
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
|
||||
warnings.warn(
|
||||
"Clip received a complex value either through the input or the min/max "
|
||||
"keywords. Complex values have no ordering and cannot be clipped. "
|
||||
"Attempting to clip using complex numbers is deprecated and will soon "
|
||||
"raise a ValueError. Please convert to a real value or array by taking "
|
||||
"the real or imaginary components via jax.numpy.real/imag respectively.",
|
||||
DeprecationWarning, stacklevel=2,
|
||||
)
|
||||
if min is not None:
|
||||
x = ufuncs.maximum(min, x)
|
||||
if max is not None:
|
||||
x = ufuncs.minimum(max, x)
|
||||
return asarray(x)
|
||||
|
||||
@util.implements(np.around, skip_params=['out'])
|
||||
@partial(jit, static_argnames=('decimals',))
|
||||
|
@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
|
||||
m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
|
||||
s_over_a = (s_ + m) / (a_ + N)
|
||||
T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
|
||||
T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
|
||||
T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
|
||||
coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
|
||||
tuple(range(a.ndim)))
|
||||
T1 = T1 / coefs
|
||||
|
@ -77,3 +77,9 @@ class DuckTypedArray(Protocol):
|
||||
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
|
||||
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
|
||||
# nor does it accept string data.
|
||||
|
||||
# We use a class for deprecated args to avoid using Any/object types which can
|
||||
# introduce complications and mistakes in static analysis
|
||||
class DeprecatedArg:
|
||||
def __repr__(self):
|
||||
return "Deprecated"
|
||||
|
@ -112,6 +112,7 @@ from jax.experimental.array_api._elementwise_functions import (
|
||||
bitwise_right_shift as bitwise_right_shift,
|
||||
bitwise_xor as bitwise_xor,
|
||||
ceil as ceil,
|
||||
clip as clip,
|
||||
conj as conj,
|
||||
cos as cos,
|
||||
cosh as cosh,
|
||||
|
@ -125,6 +125,22 @@ def ceil(x, /):
|
||||
return jax.numpy.ceil(x)
|
||||
|
||||
|
||||
def clip(x, /, min=None, max=None):
|
||||
"""Returns the complex conjugate for each element x_i of the input array x."""
|
||||
x, = _promote_dtypes("clip", x)
|
||||
|
||||
# TODO(micky774): Remove when jnp.clip deprecation is completed
|
||||
# (began 2024-4-2) and default behavior is Array API 2023 compliant
|
||||
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
|
||||
raise ValueError(
|
||||
"Clip received a complex value either through the input or the min/max "
|
||||
"keywords. Complex values have no ordering and cannot be clipped. "
|
||||
"Please convert to a real value or array by taking the real or "
|
||||
"imaginary components via jax.numpy.real/imag respectively."
|
||||
)
|
||||
return jax.numpy.clip(x, min=min, max=max)
|
||||
|
||||
|
||||
def conj(x, /):
|
||||
"""Returns the complex conjugate for each element x_i of the input array x."""
|
||||
x, = _promote_dtypes("conj", x)
|
||||
|
@ -1283,7 +1283,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
|
||||
# values like 1.0000001 on float32, which are clipped to 1.0. It is
|
||||
# possible that anything other than `cos_angular_diff` can be outside
|
||||
# the interval [0, 1] due to roundoff.
|
||||
cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0)
|
||||
cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0)
|
||||
|
||||
angular_diff = jnp.arccos(cos_angular_diff)
|
||||
|
||||
|
@ -201,7 +201,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
next_t = t + dt
|
||||
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
|
||||
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
|
||||
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
|
||||
dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax)
|
||||
|
||||
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
|
||||
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
|
||||
@ -214,7 +214,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
return carry, y_target
|
||||
|
||||
f0 = func_(y0, ts[0])
|
||||
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
|
||||
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax)
|
||||
interp_coeff = jnp.array([y0] * 5)
|
||||
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
|
||||
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
|
||||
|
@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes
|
||||
from jax._src.lax.lax import PrecisionLike
|
||||
from jax._src.lax.slicing import GatherScatterMode
|
||||
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike, DType, DTypeLike,
|
||||
DimSize, DuckTypedArray, Shape, DeprecatedArg
|
||||
)
|
||||
from jax.numpy import fft as fft, linalg as linalg
|
||||
from jax.sharding import Sharding as _Sharding
|
||||
import numpy as _np
|
||||
@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ...
|
||||
character = _np.character
|
||||
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
|
||||
out: None = ..., mode: str = ...) -> Array: ...
|
||||
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ...,
|
||||
a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ...
|
||||
def clip(
|
||||
x: ArrayLike | None = ...,
|
||||
/,
|
||||
min: Optional[ArrayLike] = ...,
|
||||
max: Optional[ArrayLike] = ...,
|
||||
a: ArrayLike | DeprecatedArg | None = ...,
|
||||
a_min: ArrayLike | DeprecatedArg | None = ...,
|
||||
a_max: ArrayLike | DeprecatedArg | None = ...
|
||||
) -> Array: ...
|
||||
def column_stack(
|
||||
tup: Union[_np.ndarray, Array, Sequence[ArrayLike]]
|
||||
) -> Array: ...
|
||||
|
@ -58,6 +58,7 @@ MAIN_NAMESPACE = {
|
||||
'broadcast_to',
|
||||
'can_cast',
|
||||
'ceil',
|
||||
'clip',
|
||||
'complex128',
|
||||
'complex64',
|
||||
'concat',
|
||||
@ -233,5 +234,27 @@ class ArrayAPISmokeTest(absltest.TestCase):
|
||||
self.assertIs(x.__array_namespace__(), array_api)
|
||||
|
||||
|
||||
class ArrayAPIErrors(absltest.TestCase):
|
||||
"""Test that our array API implementations raise errors where required"""
|
||||
|
||||
# TODO(micky774): Remove when jnp.clip deprecation is completed
|
||||
# (began 2024-4-2) and default behavior is Array API 2023 compliant
|
||||
def test_clip_complex(self):
|
||||
x = array_api.arange(5, dtype=array_api.complex64)
|
||||
complex_msg = "Complex values have no ordering and cannot be clipped"
|
||||
with self.assertRaisesRegex(ValueError, complex_msg):
|
||||
array_api.clip(x)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, complex_msg):
|
||||
array_api.clip(x, max=x)
|
||||
|
||||
x = array_api.arange(5, dtype=array_api.int32)
|
||||
with self.assertRaisesRegex(ValueError, complex_msg):
|
||||
array_api.clip(x, min=-1+5j)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, complex_msg):
|
||||
array_api.clip(x, max=-1+5j)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
@ -877,7 +877,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
a_max = None if a_max is None else abs(a_max)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
|
||||
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
|
||||
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
@ -872,14 +872,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
a_max = None if a_max is None else abs(a_max)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
|
||||
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
|
||||
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testClipError(self):
|
||||
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
|
||||
jnp.clip(jnp.zeros((3,)))
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=default_dtypes + unsigned_dtypes,
|
||||
)
|
||||
def testClipNone(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
self.assertArraysEqual(jnp.clip(x), x)
|
||||
|
||||
|
||||
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
|
||||
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
|
||||
# Array API 2023 compliant
|
||||
@jtu.sample_product(shape=all_shapes)
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
||||
def testClipComplexInputDeprecation(self, shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype=jnp.complex64)
|
||||
msg = "Complex values have no ordering and cannot be clipped"
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
jnp.clip(x)
|
||||
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
jnp.clip(x, max=x)
|
||||
|
||||
x = rng(shape, dtype=jnp.int32)
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
jnp.clip(x, min=-1+5j)
|
||||
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
jnp.clip(x, max=jnp.array([-1+5j]))
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, dtype=dtype)
|
||||
@ -5772,7 +5803,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'argpartition': ['kind', 'order'],
|
||||
'asarray': ['like'],
|
||||
'broadcast_to': ['subok'],
|
||||
'clip': ['kwargs'],
|
||||
'clip': ['kwargs', 'out'],
|
||||
'copy': ['subok'],
|
||||
'corrcoef': ['ddof', 'bias', 'dtype'],
|
||||
'cov': ['dtype'],
|
||||
@ -5809,6 +5840,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
extra_params = {
|
||||
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
|
||||
# standard
|
||||
'clip': ['x', 'max', 'min'],
|
||||
'einsum': ['subscripts', 'precision'],
|
||||
'einsum_path': ['subscripts'],
|
||||
'take_along_axis': ['mode'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user