Remove obsolete implements() decorator & fix tests

This commit is contained in:
Jake VanderPlas 2024-10-28 15:22:09 -07:00
parent e82d5a973b
commit 14030801a5
2 changed files with 122 additions and 274 deletions

View File

@ -13,11 +13,9 @@
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from functools import partial
import re
import textwrap
from typing import Any, NamedTuple, TypeVar
from typing import Any
import warnings
@ -34,173 +32,6 @@ import numpy as np
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
_T = TypeVar("_T")
_parameter_break = re.compile("\n(?=[A-Za-z_])")
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE)
_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE)
_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`')
class ParsedDoc(NamedTuple):
"""
docstr: full docstring
signature: signature from docstring.
summary: summary from docstring.
front_matter: front matter before sections.
sections: dictionary of section titles to section content.
"""
docstr: str | None
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: dict[str, str] = {}
def _parse_numpydoc(docstr: str | None) -> ParsedDoc:
"""Parse a standard numpy-style docstring.
Args:
docstr: the raw docstring from a function
Returns:
ParsedDoc: parsed version of the docstring
"""
if docstr is None or not docstr.strip():
return ParsedDoc(docstr)
# Remove any :doc: directives in the docstring to avoid sphinx errors
docstr = _docreference.sub(
lambda match: f"{match.groups()[0]}", docstr)
signature, body = "", docstr
match = _numpy_signature_re.match(body)
if match:
signature = match.group()
body = docstr[match.end():]
firstline, _, body = body.partition('\n')
body = textwrap.dedent(body.lstrip('\n'))
match = _numpy_signature_re.match(body)
if match:
signature = match.group()
body = body[match.end():]
summary = firstline
if not summary:
summary, _, body = body.lstrip('\n').partition('\n')
body = textwrap.dedent(body.lstrip('\n'))
front_matter = ""
body = "\n" + body
section_list = _section_break.split(body)
if not _section_break.match(section_list[0]):
front_matter, *section_list = section_list
sections = {section.split('\n', 1)[0]: section for section in section_list}
return ParsedDoc(docstr=docstr, signature=signature, summary=summary,
front_matter=front_matter, sections=sections)
def _parse_parameters(body: str) -> dict[str, str]:
"""Parse the Parameters section of a docstring."""
title, underline, content = body.split('\n', 2)
assert title == 'Parameters'
assert underline and not underline.strip('-')
parameters = _parameter_break.split(content)
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
module: str | None = None,
) -> Callable[[_T], _T]:
"""Decorator for JAX functions which implement a specified NumPy function.
This mainly contains logic to copy and modify the docstring of the original
function. In particular, if `update_doc` is True, parameters listed in the
original function that are not supported by the decorated function will
be removed from the docstring. For this reason, it is important that parameter
names match those in the original numpy function.
Args:
original_fun: The original function being implemented
update_doc: whether to transform the numpy docstring to remove references of
parameters that are supported by the numpy version but not the JAX version.
If False, include the numpy docstring verbatim.
sections: a list of sections to include in the docstring. The default is
["Parameters", "Returns", "References"]
module: an optional string specifying the module from which the original function
is imported. This is useful for objects such as ufuncs, where the module cannot
be determined from the original function itself.
"""
def decorator(wrapped_fun):
wrapped_fun.__np_wrapped__ = original_fun
# Allows this pattern: @implements(getattr(np, 'new_function', None))
if original_fun is None:
return wrapped_fun
docstr = getattr(original_fun, "__doc__", None)
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
try:
mod = module or original_fun.__module__
except AttributeError:
if config.enable_checks.value:
raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().")
else:
name = f"{mod}.{name}"
if docstr:
try:
parsed = _parse_numpydoc(docstr)
if update_doc and 'Parameters' in parsed.sections:
code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None)
# Remove unrecognized parameter descriptions.
parameters = _parse_parameters(parsed.sections['Parameters'])
parameters = {p: desc for p, desc in parameters.items()
if (code is None or p in code.co_varnames)}
if parameters:
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
"\n".join(_versionadded.split(desc)[0].rstrip()
for p, desc in parameters.items())
)
else:
del parsed.sections['Parameters']
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
docstr += "\n*Original docstring below.*\n"
# We remove signatures from the docstrings, because they redundant at best and
# misleading at worst: e.g. JAX wrappers don't implement all ufunc keyword arguments.
# if parsed.signature:
# docstr += "\n" + parsed.signature.strip() + "\n"
if parsed.front_matter:
docstr += "\n" + parsed.front_matter.strip() + "\n"
kept_sections = (content.strip() for section, content in parsed.sections.items()
if section in sections)
if kept_sections:
docstr += "\n" + "\n\n".join(kept_sections) + "\n"
except:
if config.enable_checks.value:
raise
docstr = original_fun.__doc__
wrapped_fun.__doc__ = docstr
for attr in ['__name__', '__qualname__']:
try:
value = getattr(original_fun, attr)
except AttributeError:
pass
else:
setattr(wrapped_fun, attr, value)
return wrapped_fun
return decorator
_dtype = partial(dtypes.dtype, canonicalize=True)
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:

View File

@ -51,7 +51,6 @@ from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements
from jax._src.util import safe_zip, NumpyComplexWarning
config.parse_flags_with_absl()
@ -6186,9 +6185,114 @@ class NumpySignaturesTest(jtu.JaxTestCase):
def testWrappedSignaturesMatch(self):
"""Test that jax.numpy function signatures match numpy."""
jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)}
func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items()
if getattr(fun, '__np_wrapped__', None) is not None}
# NumPy functions explicitly not implemented in JAX:
skip = {'array2string',
'asanyarray',
'asarray_chkfinite',
'ascontiguousarray',
'asfortranarray',
'asmatrix',
'base_repr',
'binary_repr',
'bmat',
'broadcast',
'busday_count',
'busday_offset',
'busdaycalendar',
'common_type',
'copyto',
'datetime_as_string',
'datetime_data',
'errstate',
'flatiter',
'format_float_positional',
'format_float_scientific',
'fromregex',
'genfromtxt',
'get_include',
'getbufsize',
'geterr',
'geterrcall',
'in1d',
'info',
'is_busday',
'isfortran',
'isnat',
'loadtxt',
'matrix',
'may_share_memory',
'memmap',
'min_scalar_type',
'mintypecode',
'ndenumerate',
'ndindex',
'nditer',
'nested_iters',
'poly1d',
'put_along_axis',
'putmask',
'real_if_close',
'recarray',
'record',
'require',
'row_stack',
'savetxt',
'savez_compressed',
'setbufsize',
'seterr',
'seterrcall',
'shares_memory',
'show_config',
'show_runtime',
'test',
'trapz',
'typename'}
# symbols removed in NumPy 2.0
skip |= {'add_docstring',
'add_newdoc',
'add_newdoc_ufunc',
'alltrue',
'asfarray',
'byte_bounds',
'compare_chararrays',
'cumproduct',
'deprecate',
'deprecate_with_doc',
'disp',
'fastCopyAndTranspose',
'find_common_type',
'get_array_wrap',
'geterrobj',
'issctype',
'issubclass_',
'issubsctype',
'lookfor',
'mat',
'maximum_sctype',
'msort',
'obj2sctype',
'product',
'recfromcsv',
'recfromtxt',
'round_',
'safe_eval',
'sctype2char',
'set_numeric_ops',
'set_string_function',
'seterrobj',
'sometrue',
'source',
'who'}
self.assertEmpty(skip.intersection(dir(jnp)))
names = (name for name in dir(np) if not (name.startswith('_') or name in skip))
names = (name for name in names if callable(getattr(np, name)))
names = {name for name in names if not isinstance(getattr(np, name), type)}
self.assertEmpty(names.difference(dir(jnp)))
self.assertNotEmpty(names)
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
@ -6199,6 +6303,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
'cumulative_prod': ['out'],
'cumulative_sum': ['out'],
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
@ -6210,9 +6315,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'full': ['order', 'like'],
'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
@ -6222,7 +6325,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'partition': ['kind', 'order'],
'percentile': ['weights'],
'quantile': ['weights'],
'reshape': ['shape', 'copy'],
'row_stack': ['casting'],
'stack': ['casting'],
'std': ['mean'],
@ -6233,18 +6335,19 @@ class NumpySignaturesTest(jtu.JaxTestCase):
}
extra_params = {
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
# standard
'clip': ['x', 'max', 'min'],
'compress': ['size', 'fill_value'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'load': ['args', 'kwargs'],
'take_along_axis': ['mode', 'fill_value'],
'fill_diagonal': ['inplace'],
}
mismatches = {}
for name, (jnp_fun, np_fun) in func_pairs.items():
for name in names:
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if name in ['histogram', 'histogram2d', 'histogramdd']:
# numpy 1.24 re-orders the density and weights arguments.
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
@ -6258,12 +6361,15 @@ class NumpySignaturesTest(jtu.JaxTestCase):
# TODO(dfm): After our deprecation period for the clip arguments ends
# it should be possible to reintroduce the check.
continue
# Note: can't use inspect.getfullargspec due to numpy issue
if name == "reshape":
# Similar issue to clip: we'd need logic specific to the NumPy version
# because of the change in argument name from `newshape` to `shape`.
continue
# Note: can't use inspect.getfullargspec for some functions due to numpy issue
# https://github.com/numpy/numpy/issues/12225
try:
np_params = inspect.signature(np_fun).parameters
except ValueError:
# Some functions cannot be inspected
continue
jnp_params = inspect.signature(jnp_fun).parameters
extra = set(extra_params.get(name, []))
@ -6350,8 +6456,6 @@ class NumpyUfuncTests(jtu.JaxTestCase):
class NumpyDocTests(jtu.JaxTestCase):
def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.
unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift',
@ -6371,15 +6475,6 @@ class NumpyDocTests(jtu.JaxTestCase):
elif hasattr(np, name) and obj is getattr(np, name):
# Some APIs are imported directly from NumPy; we don't check these.
pass
elif hasattr(obj, '__np_wrapped__'):
# Functions decorated with @implements(...) should have __np_wrapped__
wrapped_fun = obj.__np_wrapped__
if wrapped_fun is not None:
# If the wrapped function has a docstring, obj should too
if wrapped_fun.__doc__ and not obj.__doc__:
raise Exception(f"jnp.{name} does not contain wrapped docstring.")
if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__:
raise Exception(f"jnp.{name} does not have a wrapped docstring.")
elif name in aliases:
assert "Alias of" in obj.__doc__
elif name not in skip_args_check:
@ -6391,84 +6486,6 @@ class NumpyDocTests(jtu.JaxTestCase):
if name not in ["frompyfunc", "isdtype", "promote_types"]:
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")
@parameterized.named_parameters(
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
def test_wrapped_function_parameters(self, jit):
def orig(x):
"""Example Docstring
Parameters
----------
x : array_like
Input Data
.. versionadded:: 1.8.0
out : array_like, optional
Output to overwrite
other_arg : Any
not used
Returns
-------
x : input
"""
return x
def wrapped(x, out=None):
return x
if jit:
wrapped = jax.jit(wrapped)
wrapped = implements(orig)(wrapped)
doc = wrapped.__doc__
self.assertStartsWith(doc, "Example Docstring")
self.assertIn("Original docstring below", doc)
self.assertIn("Parameters", doc)
self.assertIn("Returns", doc)
self.assertNotIn('other_arg', doc)
self.assertNotIn('versionadded', doc)
def test_parse_numpydoc(self):
# Unit test ensuring that _parse_numpydoc correctly parses docstrings for all
# functions in NumPy's top-level namespace.
section_titles = {'Attributes', 'Examples', 'Notes',
'Parameters', 'Raises', 'References',
'Returns', 'See also', 'See Also', 'Warnings', 'Warns'}
headings = [title + '\n' + '-'*len(title) for title in section_titles]
for name in dir(np):
if name.startswith('_'):
continue
obj = getattr(np, name)
if isinstance(obj, type):
continue
if not callable(obj):
continue
if 'built-in function' in repr(obj):
continue
parsed = _parse_numpydoc(obj.__doc__)
# Check that no docstring is handled gracefully.
if not obj.__doc__:
self.assertEqual(parsed, ParsedDoc(obj.__doc__))
continue
# Check that no unexpected section names are found.
extra_keys = parsed.sections.keys() - section_titles
if extra_keys:
raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}")
# Check that every docstring has a summary.
if not parsed.summary:
raise ValueError(f"No summary found for np.{name}")
# Check that no expected headings are missed.
for heading in headings:
assert heading not in parsed.front_matter
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())