Merge pull request #20550 from Micky774:api_clip

PiperOrigin-RevId: 622045823
This commit is contained in:
jax authors 2024-04-04 19:58:06 -07:00
commit 2512843a56
13 changed files with 168 additions and 30 deletions

View File

@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).
## jaxlib 0.4.27

View File

@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:
def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None,
out: None = None) -> Array:
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.
Refer to :func:`jax.numpy.clip` for full documentation."""
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
return lax_numpy.clip(number, min=min, max=max)
def _transpose(a: Array, *args: Any) -> Array:

View File

@ -66,7 +66,10 @@ from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis,
@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)
@util.implements(np.clip, skip_params=['out'])
_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
a_max: ArrayLike | None = None, out: None = None) -> Array:
util.check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = ufuncs.maximum(a_min, a)
if a_max is not None:
a = ufuncs.minimum(a_max, a)
return asarray(a)
def clip(
x: ArrayLike | None = None, # Default to preserve backwards compatability
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
) -> Array:
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)
util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)
@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))

View File

@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
s_over_a = (s_ + m) / (a_ + N)
T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
tuple(range(a.ndim)))
T1 = T1 / coefs

View File

@ -77,3 +77,9 @@ class DuckTypedArray(Protocol):
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data.
# We use a class for deprecated args to avoid using Any/object types which can
# introduce complications and mistakes in static analysis
class DeprecatedArg:
def __repr__(self):
return "Deprecated"

View File

@ -112,6 +112,7 @@ from jax.experimental.array_api._elementwise_functions import (
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
clip as clip,
conj as conj,
cos as cos,
cosh as cosh,

View File

@ -125,6 +125,22 @@ def ceil(x, /):
return jax.numpy.ceil(x)
def clip(x, /, min=None, max=None):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("clip", x)
# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively."
)
return jax.numpy.clip(x, min=min, max=max)
def conj(x, /):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("conj", x)

View File

@ -1283,7 +1283,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
# values like 1.0000001 on float32, which are clipped to 1.0. It is
# possible that anything other than `cos_angular_diff` can be outside
# the interval [0, 1] due to roundoff.
cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0)
cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0)
angular_diff = jnp.arccos(cos_angular_diff)

View File

@ -201,7 +201,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
next_t = t + dt
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax)
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
@ -214,7 +214,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
return carry, y_target
f0 = func_(y0, ts[0])
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax)
interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:])

View File

@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
from jax._src.typing import (
Array, ArrayLike, DType, DTypeLike,
DimSize, DuckTypedArray, Shape, DeprecatedArg
)
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
import numpy as _np
@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ...
character = _np.character
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = ..., mode: str = ...) -> Array: ...
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ...,
a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ...
def clip(
x: ArrayLike | None = ...,
/,
min: Optional[ArrayLike] = ...,
max: Optional[ArrayLike] = ...,
a: ArrayLike | DeprecatedArg | None = ...,
a_min: ArrayLike | DeprecatedArg | None = ...,
a_max: ArrayLike | DeprecatedArg | None = ...
) -> Array: ...
def column_stack(
tup: Union[_np.ndarray, Array, Sequence[ArrayLike]]
) -> Array: ...

View File

@ -58,6 +58,7 @@ MAIN_NAMESPACE = {
'broadcast_to',
'can_cast',
'ceil',
'clip',
'complex128',
'complex64',
'concat',
@ -233,5 +234,27 @@ class ArrayAPISmokeTest(absltest.TestCase):
self.assertIs(x.__array_namespace__(), array_api)
class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""
# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def test_clip_complex(self):
x = array_api.arange(5, dtype=array_api.complex64)
complex_msg = "Complex values have no ordering and cannot be clipped"
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=x)
x = array_api.arange(5, dtype=array_api.int32)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, min=-1+5j)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)
if __name__ == '__main__':
absltest.main()

View File

@ -877,7 +877,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

View File

@ -872,14 +872,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
def testClipError(self):
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
jnp.clip(jnp.zeros((3,)))
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes + unsigned_dtypes,
)
def testClipNone(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Complex values have no ordering and cannot be clipped"
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=x)
x = rng(shape, dtype=jnp.int32)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, min=-1+5j)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=jnp.array([-1+5j]))
@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
@ -5772,7 +5803,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],
'clip': ['kwargs', 'out'],
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
@ -5809,6 +5840,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
}
extra_params = {
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
# standard
'clip': ['x', 'max', 'min'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'take_along_axis': ['mode'],