DOC: remove unimplemneted parameters from lax.numpy docstrings

This commit is contained in:
Jake VanderPlas 2021-03-25 14:47:18 -07:00
parent 4e99b0e488
commit 5c098b11c5
2 changed files with 52 additions and 44 deletions

View File

@ -1243,7 +1243,7 @@ def _gradient(a, varargs, axis):
return a_grad
@_wraps(np.gradient)
@_wraps(np.gradient, skip_params=['edge_order'])
def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
edge_order=None):
if edge_order is not None:
@ -1785,7 +1785,7 @@ dsplit = _split_on_axis(np.dsplit, axis=2)
def array_split(ary, indices_or_sections, axis: int = 0):
return _split("array_split", ary, indices_or_sections, axis=axis)
@_wraps(np.clip)
@_wraps(np.clip, skip_params=['out'])
def clip(a, a_min=None, a_max=None, out=None):
_check_arraylike("clip", a)
if out is not None:
@ -1798,7 +1798,7 @@ def clip(a, a_min=None, a_max=None, out=None):
a = minimum(a_max, a)
return a
@_wraps(np.round, update_doc=False)
@_wraps(np.round, update_doc=False, skip_params=['out'])
def round(a, decimals=0, out=None):
_check_arraylike("round", a)
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
@ -1832,7 +1832,7 @@ def round(a, decimals=0, out=None):
around = round
@_wraps(np.fix)
@_wraps(np.fix, skip_params=['out'])
def fix(x, out=None):
_check_arraylike("fix", x)
if out is not None:
@ -1841,7 +1841,7 @@ def fix(x, out=None):
return where(lax.ge(x, zero), floor(x), ceil(x))
@_wraps(np.modf)
@_wraps(np.modf, skip_params=['out'])
def modf(x, out=None):
_check_arraylike("modf", x)
if out is not None:
@ -1886,9 +1886,11 @@ def _isposneginf(infinity, x, out):
else:
return full_like(x, False, dtype=bool_)
isposinf = _wraps(np.isposinf)(lambda x, out=None: _isposneginf(inf, x, out))
isposinf = _wraps(np.isposinf, skip_params=['out'])(
lambda x, out=None: _isposneginf(inf, x, out))
isneginf = _wraps(np.isneginf)(lambda x, out=None: _isposneginf(-inf, x, out))
isneginf = _wraps(np.isneginf, skip_params=['out'])(
lambda x, out=None: _isposneginf(-inf, x, out))
@_wraps(np.isnan)
def isnan(x):
@ -1919,6 +1921,9 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
axis=None, dtype=None, out=None, keepdims=False, initial=None,
where_=None, parallel_reduce=None):
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
# object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
_check_arraylike(name, a)
@ -1990,7 +1995,7 @@ def _reduction_init_val(a, init_val):
_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)
@_wraps(np.sum)
@_wraps(np.sum, skip_params=['out'])
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
return _reduction(a, "sum", np.sum, lax.add, 0,
@ -1998,34 +2003,34 @@ def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum)
@_wraps(np.prod)
@_wraps(np.prod, skip_params=['out'])
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
return _reduction(a, "prod", np.prod, lax.mul, 1,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where)
@_wraps(np.max)
@_wraps(np.max, skip_params=['out'])
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.min)
@_wraps(np.min, skip_params=['out'])
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.all)
@_wraps(np.all, skip_params=['out'])
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.any)
@_wraps(np.any, skip_params=['out'])
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
@ -2046,7 +2051,7 @@ def _axis_size(a, axis):
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
@_wraps(np.mean)
@_wraps(np.mean, skip_params=['out'])
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None):
_check_arraylike("mean", a)
@ -2123,7 +2128,7 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
return avg
@_wraps(np.var)
@_wraps(np.var, skip_params=['out'])
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("var", a)
@ -2173,7 +2178,7 @@ def _var_promote_types(a_dtype, dtype):
return a_dtype, dtype
@_wraps(np.std)
@_wraps(np.std, skip_params=['out'])
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("std", a)
@ -2183,7 +2188,7 @@ def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@_wraps(np.ptp)
@_wraps(np.ptp, skip_params=['out'])
def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False):
_check_arraylike("ptp", a)
@ -2241,31 +2246,31 @@ def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
else:
return out
@_wraps(np.nanmin)
@_wraps(np.nanmin, skip_params=['out'])
def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=True,
axis=axis, out=out, keepdims=keepdims)
@_wraps(np.nanmax)
@_wraps(np.nanmax, skip_params=['out'])
def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=True,
axis=axis, out=out, keepdims=keepdims)
@_wraps(np.nansum)
@_wraps(np.nansum, skip_params=['out'])
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None):
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
@_wraps(np.nanprod)
@_wraps(np.nanprod, skip_params=['out'])
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None):
return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
@_wraps(np.nanmean)
@_wraps(np.nanmean, skip_params=['out'])
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False):
_check_arraylike("nanmean", a)
@ -2283,7 +2288,7 @@ def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
return td
@_wraps(np.nanvar)
@_wraps(np.nanvar, skip_params=['out'])
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
_check_arraylike("nanvar", a)
@ -2314,7 +2319,7 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
return lax.convert_element_type(out, dtype)
@_wraps(np.nanstd)
@_wraps(np.nanstd, skip_params=['out'])
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
_check_arraylike("nanstd", a)
@ -2349,7 +2354,7 @@ def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_val
return reduction(a, axis)
@_wraps(np_reduction)
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None):
@ -2699,7 +2704,7 @@ def pad(array, pad_width, mode="constant", **kwargs):
return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)
@_wraps(np.stack)
@_wraps(np.stack, skip_params=['out'])
def stack(arrays, axis: int =0, out=None):
if not len(arrays):
raise ValueError("Need at least one array to stack.")
@ -2783,7 +2788,7 @@ def column_stack(tup):
return concatenate(arrays, 1)
@_wraps(np.choose)
@_wraps(np.choose, skip_params=['out'])
def choose(a, choices, out=None, mode='raise'):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
@ -3364,7 +3369,7 @@ def triu(m, k=0):
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
@_wraps(np.trace)
@_wraps(np.trace, skip_params=['out'])
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
_check_arraylike("trace", a)
if out is not None:
@ -3793,7 +3798,7 @@ def tensordot(a, b, axes=2, *, precision=None):
precision=precision)
@_wraps(np.einsum, lax_description=_PRECISION_DOC)
@_wraps(np.einsum, lax_description=_PRECISION_DOC, skip_params=['out'])
def einsum(*operands, out=None, optimize='greedy', precision=None,
_use_xeinsum=False):
if out is not None:
@ -3970,7 +3975,7 @@ def inner(a, b, *, precision=None):
return tensordot(a, b, (-1, -1), precision=precision)
@_wraps(np.outer)
@_wraps(np.outer, skip_params=['out'])
def outer(a, b, out=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
@ -4049,7 +4054,7 @@ def argwhere(a):
return result.reshape(result.shape[0], ndim(a))
@_wraps(np.argmax)
@_wraps(np.argmax, skip_params=['out'])
def argmax(a, axis: Optional[int] = None, out=None):
_check_arraylike("argmax", a)
if out is not None:
@ -4061,7 +4066,7 @@ def argmax(a, axis: Optional[int] = None, out=None):
raise ValueError("attempt to get argmax of an empty sequence")
return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)
@_wraps(np.argmin)
@_wraps(np.argmin, skip_params=['out'])
def argmin(a, axis: Optional[int] = None, out=None):
_check_arraylike("argmin", a)
if out is not None:
@ -4246,7 +4251,7 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
return swapaxes(unpacked, axis, -1)
@_wraps(np.take)
@_wraps(np.take, skip_params=['out'])
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
@ -4423,7 +4428,7 @@ def _unique1d(ar, return_index=False, return_inverse=False,
ret += (diff(idx),)
return ret
@_wraps(np.unique)
@_wraps(np.unique, skip_params=['axis'])
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None):
ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()")
@ -4914,7 +4919,7 @@ def extract(condition, arr):
return compress(ravel(condition), ravel(arr))
@_wraps(np.compress)
@_wraps(np.compress, skip_params=['out'])
def compress(condition, a, axis: Optional[int] = None, out=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
@ -5020,7 +5025,7 @@ def corrcoef(x, y=None, rowvar=True):
return c
@_wraps(getattr(np, "quantile", None))
@_wraps(getattr(np, "quantile", None), skip_params=['out', 'overwrite_input'])
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, interpolation="linear", keepdims=False):
_check_arraylike("quantile", a, q)
@ -5030,7 +5035,7 @@ def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
raise ValueError(msg)
return _quantile(a, q, axis, interpolation, keepdims, False)
@_wraps(getattr(np, "nanquantile", None))
@_wraps(getattr(np, "nanquantile", None), skip_params=['out', 'overwrite_input'])
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
@ -5158,7 +5163,7 @@ def _searchsorted(a, v, side):
return lax.fori_loop(0, n_levels, body_fun, (0, len(a)))[1]
@_wraps(np.searchsorted)
@_wraps(np.searchsorted, skip_params=['sorter'])
def searchsorted(a, v, side='left', sorter=None):
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'")
@ -5209,7 +5214,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)
@_wraps(np.percentile)
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
@ -5218,7 +5223,7 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
@_wraps(np.nanpercentile)
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, interpolation="linear",
keepdims=False):
@ -5227,14 +5232,14 @@ def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)
@_wraps(np.median)
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, keepdims=False):
_check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, interpolation='midpoint')
@_wraps(np.nanmedian)
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, keepdims=False):
_check_arraylike("nanmedian", a)

View File

@ -86,7 +86,8 @@ def _parse_parameters(body: str) -> Dict[str, str]:
def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References')):
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = ()):
"""Specialized version of functools.wraps for wrapping numpy functions.
This produces a wrapped function with a modified docstring. In particular, if
@ -104,6 +105,8 @@ def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
the docstring.
sections: a list of sections to include in the docstring. The default is
["Parameters", "returns", "References"]
skip_params: a list of strings containing names of parameters accepted by the
function that should be skipped in the parameter list.
"""
def wrap(op):
docstr = getattr(fun, "__doc__", None)
@ -118,7 +121,7 @@ def _wraps(fun: Callable, update_doc: bool = True, lax_description: str = "",
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip() for p, desc in parameters.items()
if p in op.__code__.co_varnames)
if p in op.__code__.co_varnames and p not in skip_params)
)
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""