Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
DOC: remove unimplemented parameters from lax.numpy docstrings
This commit is contained in:
parent 4e99b0e488
commit 5c098b11c5
@@ -1243,7 +1243,7 @@ def _gradient(a, varargs, axis):
   return a_grad


-@_wraps(np.gradient)
+@_wraps(np.gradient, skip_params=['edge_order'])
 def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              edge_order=None):
   if edge_order is not None:
@@ -1785,7 +1785,7 @@ dsplit = _split_on_axis(np.dsplit, axis=2)
 def array_split(ary, indices_or_sections, axis: int = 0):
   return _split("array_split", ary, indices_or_sections, axis=axis)

-@_wraps(np.clip)
+@_wraps(np.clip, skip_params=['out'])
 def clip(a, a_min=None, a_max=None, out=None):
   _check_arraylike("clip", a)
   if out is not None:
@@ -1798,7 +1798,7 @@ def clip(a, a_min=None, a_max=None, out=None):
     a = minimum(a_max, a)
   return a

-@_wraps(np.round, update_doc=False)
+@_wraps(np.round, update_doc=False, skip_params=['out'])
 def round(a, decimals=0, out=None):
   _check_arraylike("round", a)
   decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
@@ -1832,7 +1832,7 @@ def round(a, decimals=0, out=None):
 around = round


-@_wraps(np.fix)
+@_wraps(np.fix, skip_params=['out'])
 def fix(x, out=None):
   _check_arraylike("fix", x)
   if out is not None:
@@ -1841,7 +1841,7 @@ def fix(x, out=None):
   return where(lax.ge(x, zero), floor(x), ceil(x))


-@_wraps(np.modf)
+@_wraps(np.modf, skip_params=['out'])
 def modf(x, out=None):
   _check_arraylike("modf", x)
   if out is not None:
@@ -1886,9 +1886,11 @@ def _isposneginf(infinity, x, out):
   else:
     return full_like(x, False, dtype=bool_)

-isposinf = _wraps(np.isposinf)(lambda x, out=None: _isposneginf(inf, x, out))
+isposinf = _wraps(np.isposinf, skip_params=['out'])(
+    lambda x, out=None: _isposneginf(inf, x, out))

-isneginf = _wraps(np.isneginf)(lambda x, out=None: _isposneginf(-inf, x, out))
+isneginf = _wraps(np.isneginf, skip_params=['out'])(
+    lambda x, out=None: _isposneginf(-inf, x, out))

 @_wraps(np.isnan)
 def isnan(x):
@@ -1919,6 +1921,9 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
                axis=None, dtype=None, out=None, keepdims=False, initial=None,
                where_=None, parallel_reduce=None):
   bool_op = bool_op or op
+  # Note: we must accept out=None as an argument, because numpy reductions delegate to
+  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
+  # exists, passing along all its arguments.
   if out is not None:
     raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
   _check_arraylike(name, a)
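The note added in the hunk above is why every reduction keeps an explicit `out=None` parameter even though it is unsupported. As a quick illustration (not part of the commit; assumes a working JAX installation), NumPy's module-level reductions forward their keyword arguments, including `out=None`, to the array's own method:

import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)

# np.sum() sees that x defines a .sum() method and delegates to it,
# forwarding axis=None, dtype=None and out=None along the way.
# jnp's wrapper must therefore accept out=None without complaint.
print(np.sum(x))  # 6.0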
@@ -1990,7 +1995,7 @@ def _reduction_init_val(a, init_val):

 _cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)

-@_wraps(np.sum)
+@_wraps(np.sum, skip_params=['out'])
 def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=None, initial=None, where=None):
   return _reduction(a, "sum", np.sum, lax.add, 0,
@@ -1998,34 +2003,34 @@ def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.psum)

-@_wraps(np.prod)
+@_wraps(np.prod, skip_params=['out'])
 def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=None, initial=None, where=None):
   return _reduction(a, "prod", np.prod, lax.mul, 1,
                     bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where)

-@_wraps(np.max)
+@_wraps(np.max, skip_params=['out'])
 def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=None, initial=None, where=None):
   return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                     axis=axis, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.pmax)

-@_wraps(np.min)
+@_wraps(np.min, skip_params=['out'])
 def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=None, initial=None, where=None):
   return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                     axis=axis, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.pmin)

-@_wraps(np.all)
+@_wraps(np.all, skip_params=['out'])
 def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=None, *, where=None):
   return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
                     axis=axis, out=out, keepdims=keepdims, where_=where)

-@_wraps(np.any)
+@_wraps(np.any, skip_params=['out'])
 def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=None, *, where=None):
   return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
@@ -2046,7 +2051,7 @@ def _axis_size(a, axis):
     size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
   return size

-@_wraps(np.mean)
+@_wraps(np.mean, skip_params=['out'])
 def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=False, *, where=None):
   _check_arraylike("mean", a)
@@ -2123,7 +2128,7 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
   return avg


-@_wraps(np.var)
+@_wraps(np.var, skip_params=['out'])
 def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
   _check_arraylike("var", a)
@@ -2173,7 +2178,7 @@ def _var_promote_types(a_dtype, dtype):
   return a_dtype, dtype


-@_wraps(np.std)
+@_wraps(np.std, skip_params=['out'])
 def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, ddof=0, keepdims=False, *, where=None):
   _check_arraylike("std", a)
@@ -2183,7 +2188,7 @@ def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
   return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


-@_wraps(np.ptp)
+@_wraps(np.ptp, skip_params=['out'])
 def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
         keepdims=False):
   _check_arraylike("ptp", a)
@@ -2241,31 +2246,31 @@ def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
   else:
     return out

-@_wraps(np.nanmin)
+@_wraps(np.nanmin, skip_params=['out'])
 def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
            keepdims=None):
   return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=True,
                         axis=axis, out=out, keepdims=keepdims)

-@_wraps(np.nanmax)
+@_wraps(np.nanmax, skip_params=['out'])
 def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
            keepdims=None):
   return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=True,
                         axis=axis, out=out, keepdims=keepdims)

-@_wraps(np.nansum)
+@_wraps(np.nansum, skip_params=['out'])
 def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=None):
   return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                         axis=axis, dtype=dtype, out=out, keepdims=keepdims)

-@_wraps(np.nanprod)
+@_wraps(np.nanprod, skip_params=['out'])
 def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
             out=None, keepdims=None):
   return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                         axis=axis, dtype=dtype, out=out, keepdims=keepdims)

-@_wraps(np.nanmean)
+@_wraps(np.nanmean, skip_params=['out'])
 def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
             out=None, keepdims=False):
   _check_arraylike("nanmean", a)
@@ -2283,7 +2288,7 @@ def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
   return td


-@_wraps(np.nanvar)
+@_wraps(np.nanvar, skip_params=['out'])
 def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, ddof=0, keepdims=False):
   _check_arraylike("nanvar", a)
@@ -2314,7 +2319,7 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
   return lax.convert_element_type(out, dtype)


-@_wraps(np.nanstd)
+@_wraps(np.nanstd, skip_params=['out'])
 def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, ddof=0, keepdims=False):
   _check_arraylike("nanstd", a)
@@ -2349,7 +2354,7 @@ def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_val

     return reduction(a, axis)

-  @_wraps(np_reduction)
+  @_wraps(np_reduction, skip_params=['out'])
   def cumulative_reduction(a,
                            axis: Optional[Union[int, Tuple[int, ...]]] = None,
                            dtype=None, out=None):
@@ -2699,7 +2704,7 @@ def pad(array, pad_width, mode="constant", **kwargs):
   return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)


-@_wraps(np.stack)
+@_wraps(np.stack, skip_params=['out'])
 def stack(arrays, axis: int =0, out=None):
   if not len(arrays):
     raise ValueError("Need at least one array to stack.")
@@ -2783,7 +2788,7 @@ def column_stack(tup):
   return concatenate(arrays, 1)


-@_wraps(np.choose)
+@_wraps(np.choose, skip_params=['out'])
 def choose(a, choices, out=None, mode='raise'):
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
@@ -3364,7 +3369,7 @@ def triu(m, k=0):
   return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)


-@_wraps(np.trace)
+@_wraps(np.trace, skip_params=['out'])
 def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
   _check_arraylike("trace", a)
   if out is not None:
@@ -3793,7 +3798,7 @@ def tensordot(a, b, axes=2, *, precision=None):
                    precision=precision)


-@_wraps(np.einsum, lax_description=_PRECISION_DOC)
+@_wraps(np.einsum, lax_description=_PRECISION_DOC, skip_params=['out'])
 def einsum(*operands, out=None, optimize='greedy', precision=None,
            _use_xeinsum=False):
   if out is not None:
@@ -3970,7 +3975,7 @@ def inner(a, b, *, precision=None):
   return tensordot(a, b, (-1, -1), precision=precision)


-@_wraps(np.outer)
+@_wraps(np.outer, skip_params=['out'])
 def outer(a, b, out=None):
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
@@ -4049,7 +4054,7 @@ def argwhere(a):
   return result.reshape(result.shape[0], ndim(a))


-@_wraps(np.argmax)
+@_wraps(np.argmax, skip_params=['out'])
 def argmax(a, axis: Optional[int] = None, out=None):
   _check_arraylike("argmax", a)
   if out is not None:
@@ -4061,7 +4066,7 @@ def argmax(a, axis: Optional[int] = None, out=None):
     raise ValueError("attempt to get argmax of an empty sequence")
   return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)

-@_wraps(np.argmin)
+@_wraps(np.argmin, skip_params=['out'])
 def argmin(a, axis: Optional[int] = None, out=None):
   _check_arraylike("argmin", a)
   if out is not None:
@@ -4246,7 +4251,7 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
   return swapaxes(unpacked, axis, -1)


-@_wraps(np.take)
+@_wraps(np.take, skip_params=['out'])
 def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
@@ -4423,7 +4428,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
     ret += (diff(idx),)
   return ret

-@_wraps(np.unique)
+@_wraps(np.unique, skip_params=['axis'])
 def unique(ar, return_index=False, return_inverse=False,
            return_counts=False, axis: Optional[int] = None):
   ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()")
@@ -4914,7 +4919,7 @@ def extract(condition, arr):
   return compress(ravel(condition), ravel(arr))


-@_wraps(np.compress)
+@_wraps(np.compress, skip_params=['out'])
 def compress(condition, a, axis: Optional[int] = None, out=None):
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
@@ -5020,7 +5025,7 @@ def corrcoef(x, y=None, rowvar=True):
   return c


-@_wraps(getattr(np, "quantile", None))
+@_wraps(getattr(np, "quantile", None), skip_params=['out', 'overwrite_input'])
 def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
              overwrite_input=False, interpolation="linear", keepdims=False):
   _check_arraylike("quantile", a, q)
@@ -5030,7 +5035,7 @@ def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
     raise ValueError(msg)
   return _quantile(a, q, axis, interpolation, keepdims, False)

-@_wraps(getattr(np, "nanquantile", None))
+@_wraps(getattr(np, "nanquantile", None), skip_params=['out', 'overwrite_input'])
 def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                 out=None, overwrite_input=False, interpolation="linear",
                 keepdims=False):
@@ -5158,7 +5163,7 @@ def _searchsorted(a, v, side):
   return lax.fori_loop(0, n_levels, body_fun, (0, len(a)))[1]


-@_wraps(np.searchsorted)
+@_wraps(np.searchsorted, skip_params=['sorter'])
 def searchsorted(a, v, side='left', sorter=None):
   if side not in ['left', 'right']:
     raise ValueError(f"{side!r} is an invalid value for keyword 'side'")
@@ -5209,7 +5214,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
   return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)


-@_wraps(np.percentile)
+@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
 def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out=None, overwrite_input=False, interpolation="linear",
                keepdims=False):
@@ -5218,7 +5223,7 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
   return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                   interpolation=interpolation, keepdims=keepdims)

-@_wraps(np.nanpercentile)
+@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
 def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                   out=None, overwrite_input=False, interpolation="linear",
                   keepdims=False):
@@ -5227,14 +5232,14 @@ def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
   return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                      interpolation=interpolation, keepdims=keepdims)

-@_wraps(np.median)
+@_wraps(np.median, skip_params=['out', 'overwrite_input'])
 def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
            overwrite_input=False, keepdims=False):
   _check_arraylike("median", a)
   return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
                   keepdims=keepdims, interpolation='midpoint')

-@_wraps(np.nanmedian)
+@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
 def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
               overwrite_input=False, keepdims=False):
   _check_arraylike("nanmedian", a)
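Runtime behaviour in the hunks above is unchanged: the wrappers already reject a concrete `out` argument, as seen in `_reduction`, `jnp.choose`, `jnp.take`, `jnp.outer`, and `jnp.compress`; this commit only stops advertising the parameter in the generated docstrings. An illustrative sketch of what a caller still sees (not part of the diff; assumes a working JAX installation):

import jax.numpy as jnp

x = jnp.arange(4.0)
idx = jnp.array([0, 2])

print(jnp.take(x, idx))  # [0. 2.] -- 'out' simply defaults to None

# Supplying a concrete output buffer still raises, exactly as before;
# the commit only removes 'out' from the rendered docstring.
try:
  jnp.take(x, idx, out=jnp.zeros(2))
except NotImplementedError as err:
  print(err)  # The 'out' argument to jnp.take is not supported.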
@@ -86,7 +86,8 @@ def _parse_parameters(body: str) -> Dict[str, str]:


 def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
-           sections: Sequence[str] = ('Parameters', 'Returns', 'References')):
+           sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
+           skip_params: Sequence[str] = ()):
   """Specialized version of functools.wraps for wrapping numpy functions.

   This produces a wrapped function with a modified docstring. In particular, if
@@ -104,6 +105,8 @@ def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
       the docstring.
     sections: a list of sections to include in the docstring. The default is
       ["Parameters", "returns", "References"]
+    skip_params: a list of strings containing names of parameters accepted by the
+      function that should be skipped in the parameter list.
   """
   def wrap(op):
     docstr = getattr(fun, "__doc__", None)
@@ -118,7 +121,7 @@ def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
             "Parameters\n"
             "----------\n" +
             "\n".join(_versionadded.split(desc)[0].rstrip() for p, desc in parameters.items()
-                      if p in op.__code__.co_varnames)
+                      if p in op.__code__.co_varnames and p not in skip_params)
         )

       docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
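Taken together, the change to `_wraps` means a NumPy parameter only survives into the generated Parameters section if the JAX wrapper actually accepts it and it is not listed in `skip_params`. The following standalone sketch (simplified, hypothetical names; not the real `_wraps` implementation) shows just that filtering step:

from typing import Callable, Dict, Sequence

def _keep_documented_params(parameters: Dict[str, str], op: Callable,
                            skip_params: Sequence[str] = ()) -> Dict[str, str]:
  # Same predicate as the hunk above:
  #   p in op.__code__.co_varnames and p not in skip_params
  return {p: desc for p, desc in parameters.items()
          if p in op.__code__.co_varnames and p not in skip_params}

# Toy NumPy-style parameter table for a sum-like function.
numpy_params = {
    "a": "Elements to sum.",
    "axis": "Axis or axes along which a sum is performed.",
    "out": "Alternative output array in which to place the result.",
}

def jnp_sum(a, axis=None, out=None):  # stand-in for the wrapped jnp function
  ...

# 'out' is accepted by the wrapper (so co_varnames alone would keep it),
# but skip_params drops it from the rendered docstring.
print(_keep_documented_params(numpy_params, jnp_sum, skip_params=["out"]))
# {'a': 'Elements to sum.', 'axis': 'Axis or axes along which a sum is performed.'}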