Merge pull request #20550 from Micky774:api_clip

PiperOrigin-RevId: 622045823
This commit is contained in:
jax authors 2024-04-04 19:58:06 -07:00
commit 2512843a56
13 changed files with 168 additions and 30 deletions

View File

@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect. `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).
## jaxlib 0.4.27 ## jaxlib 0.4.27

View File

@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:
def _clip(number: ArrayLike, def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
out: None = None) -> Array:
"""Return an array whose values are limited to a specified range. """Return an array whose values are limited to a specified range.
Refer to :func:`jax.numpy.clip` for full documentation.""" Refer to :func:`jax.numpy.clip` for full documentation."""
return lax_numpy.clip(number, a_min=min, a_max=max, out=out) return lax_numpy.clip(number, min=min, max=max)
def _transpose(a: Array, *args: Any) -> Array: def _transpose(a: Array, *args: Any) -> Array:

View File

@ -66,7 +66,10 @@ from jax._src.numpy import reductions
from jax._src.numpy import ufuncs from jax._src.numpy import ufuncs
from jax._src.numpy import util from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
)
from jax._src.util import (unzip2, subvals, safe_zip, from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list, ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis, canonicalize_axis as _canonicalize_axis,
@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]: axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis) return _split("array_split", ary, indices_or_sections, axis=axis)
@util.implements(np.clip, skip_params=['out'])
_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit @jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None, def clip(
a_max: ArrayLike | None = None, out: None = None) -> Array: x: ArrayLike | None = None, # Default to preserve backwards compatability
util.check_arraylike("clip", a) /,
if out is not None: min: ArrayLike | None = None,
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.") max: ArrayLike | None = None,
if a_min is None and a_max is None: *,
raise ValueError("At most one of a_min and a_max may be None") a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
if a_min is not None: a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a = ufuncs.maximum(a_min, a) a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
if a_max is not None: ) -> Array:
a = ufuncs.minimum(a_max, a) # TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
return asarray(a) x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)
util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)
@util.implements(np.around, skip_params=['out']) @util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',)) @partial(jit, static_argnames=('decimals',))

View File

@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim))) m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
s_over_a = (s_ + m) / (a_ + N) s_over_a = (s_ + m) / (a_ + N)
T1 = jnp.cumprod(s_over_a, -1)[..., ::2] T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max) T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype), coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
tuple(range(a.ndim))) tuple(range(a.ndim)))
T1 = T1 / coefs T1 = T1 / coefs

View File

@ -77,3 +77,9 @@ class DuckTypedArray(Protocol):
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt). # JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences, # It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data. # nor does it accept string data.
# We use a class for deprecated args to avoid using Any/object types which can
# introduce complications and mistakes in static analysis
class DeprecatedArg:
def __repr__(self):
return "Deprecated"

View File

@ -112,6 +112,7 @@ from jax.experimental.array_api._elementwise_functions import (
bitwise_right_shift as bitwise_right_shift, bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor, bitwise_xor as bitwise_xor,
ceil as ceil, ceil as ceil,
clip as clip,
conj as conj, conj as conj,
cos as cos, cos as cos,
cosh as cosh, cosh as cosh,

View File

@ -125,6 +125,22 @@ def ceil(x, /):
return jax.numpy.ceil(x) return jax.numpy.ceil(x)
def clip(x, /, min=None, max=None):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("clip", x)
# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively."
)
return jax.numpy.clip(x, min=min, max=max)
def conj(x, /): def conj(x, /):
"""Returns the complex conjugate for each element x_i of the input array x.""" """Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("conj", x) x, = _promote_dtypes("conj", x)

View File

@ -1283,7 +1283,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
# values like 1.0000001 on float32, which are clipped to 1.0. It is # values like 1.0000001 on float32, which are clipped to 1.0. It is
# possible that anything other than `cos_angular_diff` can be outside # possible that anything other than `cos_angular_diff` can be outside
# the interval [0, 1] due to roundoff. # the interval [0, 1] due to roundoff.
cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0) cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0)
angular_diff = jnp.arccos(cos_angular_diff) angular_diff = jnp.arccos(cos_angular_diff)

View File

@ -201,7 +201,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
next_t = t + dt next_t = t + dt
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y) error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt) new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax) dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax)
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff] new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff] old = [i + 1, y, f, t, dt, last_t, interp_coeff]
@ -214,7 +214,7 @@ def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
return carry, y_target return carry, y_target
f0 = func_(y0, ts[0]) f0 = func_(y0, ts[0])
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax) dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax)
interp_coeff = jnp.array([y0] * 5) interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff] init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:]) _, ys = lax.scan(scan_fun, init_carry, ts[1:])

View File

@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes
from jax._src.lax.lax import PrecisionLike from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape from jax._src.typing import (
Array, ArrayLike, DType, DTypeLike,
DimSize, DuckTypedArray, Shape, DeprecatedArg
)
from jax.numpy import fft as fft, linalg as linalg from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding from jax.sharding import Sharding as _Sharding
import numpy as _np import numpy as _np
@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ...
character = _np.character character = _np.character
def choose(a: ArrayLike, choices: Sequence[ArrayLike], def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = ..., mode: str = ...) -> Array: ... out: None = ..., mode: str = ...) -> Array: ...
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ..., def clip(
a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ... x: ArrayLike | None = ...,
/,
min: Optional[ArrayLike] = ...,
max: Optional[ArrayLike] = ...,
a: ArrayLike | DeprecatedArg | None = ...,
a_min: ArrayLike | DeprecatedArg | None = ...,
a_max: ArrayLike | DeprecatedArg | None = ...
) -> Array: ...
def column_stack( def column_stack(
tup: Union[_np.ndarray, Array, Sequence[ArrayLike]] tup: Union[_np.ndarray, Array, Sequence[ArrayLike]]
) -> Array: ... ) -> Array: ...

View File

@ -58,6 +58,7 @@ MAIN_NAMESPACE = {
'broadcast_to', 'broadcast_to',
'can_cast', 'can_cast',
'ceil', 'ceil',
'clip',
'complex128', 'complex128',
'complex64', 'complex64',
'concat', 'concat',
@ -233,5 +234,27 @@ class ArrayAPISmokeTest(absltest.TestCase):
self.assertIs(x.__array_namespace__(), array_api) self.assertIs(x.__array_namespace__(), array_api)
class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""
# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def test_clip_complex(self):
x = array_api.arange(5, dtype=array_api.complex64)
complex_msg = "Complex values have no ordering and cannot be clipped"
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=x)
x = array_api.arange(5, dtype=array_api.int32)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, min=-1+5j)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -877,7 +877,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
a_max = None if a_max is None else abs(a_max) a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)

View File

@ -872,14 +872,45 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
a_max = None if a_max is None else abs(a_max) a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)
def testClipError(self):
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): @jtu.sample_product(
jnp.clip(jnp.zeros((3,))) shape=all_shapes,
dtype=default_dtypes + unsigned_dtypes,
)
def testClipNone(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Complex values have no ordering and cannot be clipped"
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=x)
x = rng(shape, dtype=jnp.int32)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, min=-1+5j)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=jnp.array([-1+5j]))
@jtu.sample_product( @jtu.sample_product(
[dict(shape=shape, dtype=dtype) [dict(shape=shape, dtype=dtype)
@ -5772,7 +5803,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'argpartition': ['kind', 'order'], 'argpartition': ['kind', 'order'],
'asarray': ['like'], 'asarray': ['like'],
'broadcast_to': ['subok'], 'broadcast_to': ['subok'],
'clip': ['kwargs'], 'clip': ['kwargs', 'out'],
'copy': ['subok'], 'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'], 'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'], 'cov': ['dtype'],
@ -5809,6 +5840,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
} }
extra_params = { extra_params = {
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
# standard
'clip': ['x', 'max', 'min'],
'einsum': ['subscripts', 'precision'], 'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'], 'einsum_path': ['subscripts'],
'take_along_axis': ['mode'], 'take_along_axis': ['mode'],