diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 27496ad99..15cbc22df 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -13,11 +13,9 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import partial -import re -import textwrap -from typing import Any, NamedTuple, TypeVar +from typing import Any import warnings @@ -34,173 +32,6 @@ import numpy as np zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map -_T = TypeVar("_T") - -_parameter_break = re.compile("\n(?=[A-Za-z_])") -_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE) -_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE) -_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE) -_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`') - -class ParsedDoc(NamedTuple): - """ - docstr: full docstring - signature: signature from docstring. - summary: summary from docstring. - front_matter: front matter before sections. - sections: dictionary of section titles to section content. - """ - docstr: str | None - signature: str = "" - summary: str = "" - front_matter: str = "" - sections: dict[str, str] = {} - - -def _parse_numpydoc(docstr: str | None) -> ParsedDoc: - """Parse a standard numpy-style docstring. - - Args: - docstr: the raw docstring from a function - Returns: - ParsedDoc: parsed version of the docstring - """ - if docstr is None or not docstr.strip(): - return ParsedDoc(docstr) - - # Remove any :doc: directives in the docstring to avoid sphinx errors - docstr = _docreference.sub( - lambda match: f"{match.groups()[0]}", docstr) - - signature, body = "", docstr - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = docstr[match.end():] - - firstline, _, body = body.partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = body[match.end():] - - summary = firstline - if not summary: - summary, _, body = body.lstrip('\n').partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - front_matter = "" - body = "\n" + body - section_list = _section_break.split(body) - if not _section_break.match(section_list[0]): - front_matter, *section_list = section_list - sections = {section.split('\n', 1)[0]: section for section in section_list} - - return ParsedDoc(docstr=docstr, signature=signature, summary=summary, - front_matter=front_matter, sections=sections) - - -def _parse_parameters(body: str) -> dict[str, str]: - """Parse the Parameters section of a docstring.""" - title, underline, content = body.split('\n', 2) - assert title == 'Parameters' - assert underline and not underline.strip('-') - parameters = _parameter_break.split(content) - return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} - - -def implements( - original_fun: Callable[..., Any] | None, - update_doc: bool = True, - sections: Sequence[str] = ('Parameters', 'Returns', 'References'), - module: str | None = None, -) -> Callable[[_T], _T]: - """Decorator for JAX functions which implement a specified NumPy function. - - This mainly contains logic to copy and modify the docstring of the original - function. In particular, if `update_doc` is True, parameters listed in the - original function that are not supported by the decorated function will - be removed from the docstring. For this reason, it is important that parameter - names match those in the original numpy function. - - Args: - original_fun: The original function being implemented - update_doc: whether to transform the numpy docstring to remove references of - parameters that are supported by the numpy version but not the JAX version. - If False, include the numpy docstring verbatim. - sections: a list of sections to include in the docstring. The default is - ["Parameters", "Returns", "References"] - module: an optional string specifying the module from which the original function - is imported. This is useful for objects such as ufuncs, where the module cannot - be determined from the original function itself. - """ - def decorator(wrapped_fun): - wrapped_fun.__np_wrapped__ = original_fun - # Allows this pattern: @implements(getattr(np, 'new_function', None)) - if original_fun is None: - return wrapped_fun - docstr = getattr(original_fun, "__doc__", None) - name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) - try: - mod = module or original_fun.__module__ - except AttributeError: - if config.enable_checks.value: - raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().") - else: - name = f"{mod}.{name}" - if docstr: - try: - parsed = _parse_numpydoc(docstr) - - if update_doc and 'Parameters' in parsed.sections: - code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) - # Remove unrecognized parameter descriptions. - parameters = _parse_parameters(parsed.sections['Parameters']) - parameters = {p: desc for p, desc in parameters.items() - if (code is None or p in code.co_varnames)} - if parameters: - parsed.sections['Parameters'] = ( - "Parameters\n" - "----------\n" + - "\n".join(_versionadded.split(desc)[0].rstrip() - for p, desc in parameters.items()) - ) - else: - del parsed.sections['Parameters'] - - docstr = parsed.summary.strip() + "\n" if parsed.summary else "" - docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - docstr += "\n*Original docstring below.*\n" - - # We remove signatures from the docstrings, because they redundant at best and - # misleading at worst: e.g. JAX wrappers don't implement all ufunc keyword arguments. - # if parsed.signature: - # docstr += "\n" + parsed.signature.strip() + "\n" - - if parsed.front_matter: - docstr += "\n" + parsed.front_matter.strip() + "\n" - kept_sections = (content.strip() for section, content in parsed.sections.items() - if section in sections) - if kept_sections: - docstr += "\n" + "\n\n".join(kept_sections) + "\n" - except: - if config.enable_checks.value: - raise - docstr = original_fun.__doc__ - - wrapped_fun.__doc__ = docstr - for attr in ['__name__', '__qualname__']: - try: - value = getattr(original_fun, attr) - except AttributeError: - pass - else: - setattr(wrapped_fun, attr, value) - return wrapped_fun - return decorator - _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f853c742c..b37237cae 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,7 +51,6 @@ from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements from jax._src.util import safe_zip, NumpyComplexWarning config.parse_flags_with_absl() @@ -6186,9 +6185,114 @@ class NumpySignaturesTest(jtu.JaxTestCase): def testWrappedSignaturesMatch(self): """Test that jax.numpy function signatures match numpy.""" - jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)} - func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items() - if getattr(fun, '__np_wrapped__', None) is not None} + # NumPy functions explicitly not implemented in JAX: + skip = {'array2string', + 'asanyarray', + 'asarray_chkfinite', + 'ascontiguousarray', + 'asfortranarray', + 'asmatrix', + 'base_repr', + 'binary_repr', + 'bmat', + 'broadcast', + 'busday_count', + 'busday_offset', + 'busdaycalendar', + 'common_type', + 'copyto', + 'datetime_as_string', + 'datetime_data', + 'errstate', + 'flatiter', + 'format_float_positional', + 'format_float_scientific', + 'fromregex', + 'genfromtxt', + 'get_include', + 'getbufsize', + 'geterr', + 'geterrcall', + 'in1d', + 'info', + 'is_busday', + 'isfortran', + 'isnat', + 'loadtxt', + 'matrix', + 'may_share_memory', + 'memmap', + 'min_scalar_type', + 'mintypecode', + 'ndenumerate', + 'ndindex', + 'nditer', + 'nested_iters', + 'poly1d', + 'put_along_axis', + 'putmask', + 'real_if_close', + 'recarray', + 'record', + 'require', + 'row_stack', + 'savetxt', + 'savez_compressed', + 'setbufsize', + 'seterr', + 'seterrcall', + 'shares_memory', + 'show_config', + 'show_runtime', + 'test', + 'trapz', + 'typename'} + + # symbols removed in NumPy 2.0 + skip |= {'add_docstring', + 'add_newdoc', + 'add_newdoc_ufunc', + 'alltrue', + 'asfarray', + 'byte_bounds', + 'compare_chararrays', + 'cumproduct', + 'deprecate', + 'deprecate_with_doc', + 'disp', + 'fastCopyAndTranspose', + 'find_common_type', + 'get_array_wrap', + 'geterrobj', + 'issctype', + 'issubclass_', + 'issubsctype', + 'lookfor', + 'mat', + 'maximum_sctype', + 'msort', + 'obj2sctype', + 'product', + 'recfromcsv', + 'recfromtxt', + 'round_', + 'safe_eval', + 'sctype2char', + 'set_numeric_ops', + 'set_string_function', + 'seterrobj', + 'sometrue', + 'source', + 'who'} + + self.assertEmpty(skip.intersection(dir(jnp))) + + names = (name for name in dir(np) if not (name.startswith('_') or name in skip)) + names = (name for name in names if callable(getattr(np, name))) + names = {name for name in names if not isinstance(getattr(np, name), type)} + self.assertEmpty(names.difference(dir(jnp))) + + self.assertNotEmpty(names) # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { @@ -6199,6 +6303,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'copy': ['subok'], 'corrcoef': ['ddof', 'bias', 'dtype'], 'cov': ['dtype'], + 'cumulative_prod': ['out'], 'cumulative_sum': ['out'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], @@ -6210,9 +6315,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'full': ['order', 'like'], 'full_like': ['subok', 'order'], 'fromfunction': ['like'], - 'histogram': ['normed'], - 'histogram2d': ['normed'], - 'histogramdd': ['normed'], + 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], 'nanstd': ['correction', 'mean'], @@ -6222,7 +6325,6 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'partition': ['kind', 'order'], 'percentile': ['weights'], 'quantile': ['weights'], - 'reshape': ['shape', 'copy'], 'row_stack': ['casting'], 'stack': ['casting'], 'std': ['mean'], @@ -6233,18 +6335,19 @@ class NumpySignaturesTest(jtu.JaxTestCase): } extra_params = { - # TODO(micky774): Remove when np.clip has adopted the Array API 2023 - # standard - 'clip': ['x', 'max', 'min'], + 'compress': ['size', 'fill_value'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], + 'load': ['args', 'kwargs'], 'take_along_axis': ['mode', 'fill_value'], 'fill_diagonal': ['inplace'], } mismatches = {} - for name, (jnp_fun, np_fun) in func_pairs.items(): + for name in names: + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. @@ -6258,12 +6361,15 @@ class NumpySignaturesTest(jtu.JaxTestCase): # TODO(dfm): After our deprecation period for the clip arguments ends # it should be possible to reintroduce the check. continue - # Note: can't use inspect.getfullargspec due to numpy issue + if name == "reshape": + # Similar issue to clip: we'd need logic specific to the NumPy version + # because of the change in argument name from `newshape` to `shape`. + continue + # Note: can't use inspect.getfullargspec for some functions due to numpy issue # https://github.com/numpy/numpy/issues/12225 try: np_params = inspect.signature(np_fun).parameters except ValueError: - # Some functions cannot be inspected continue jnp_params = inspect.signature(jnp_fun).parameters extra = set(extra_params.get(name, [])) @@ -6350,8 +6456,6 @@ class NumpyUfuncTests(jtu.JaxTestCase): class NumpyDocTests(jtu.JaxTestCase): def test_lax_numpy_docstrings(self): - # Test that docstring wrapping & transformation didn't fail. - unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift', @@ -6371,15 +6475,6 @@ class NumpyDocTests(jtu.JaxTestCase): elif hasattr(np, name) and obj is getattr(np, name): # Some APIs are imported directly from NumPy; we don't check these. pass - elif hasattr(obj, '__np_wrapped__'): - # Functions decorated with @implements(...) should have __np_wrapped__ - wrapped_fun = obj.__np_wrapped__ - if wrapped_fun is not None: - # If the wrapped function has a docstring, obj should too - if wrapped_fun.__doc__ and not obj.__doc__: - raise Exception(f"jnp.{name} does not contain wrapped docstring.") - if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: - raise Exception(f"jnp.{name} does not have a wrapped docstring.") elif name in aliases: assert "Alias of" in obj.__doc__ elif name not in skip_args_check: @@ -6391,84 +6486,6 @@ class NumpyDocTests(jtu.JaxTestCase): if name not in ["frompyfunc", "isdtype", "promote_types"]: self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") - @parameterized.named_parameters( - {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) - def test_wrapped_function_parameters(self, jit): - def orig(x): - """Example Docstring - - Parameters - ---------- - x : array_like - Input Data - - .. versionadded:: 1.8.0 - out : array_like, optional - Output to overwrite - other_arg : Any - not used - - Returns - ------- - x : input - """ - return x - - def wrapped(x, out=None): - return x - - if jit: - wrapped = jax.jit(wrapped) - - wrapped = implements(orig)(wrapped) - doc = wrapped.__doc__ - - self.assertStartsWith(doc, "Example Docstring") - self.assertIn("Original docstring below", doc) - self.assertIn("Parameters", doc) - self.assertIn("Returns", doc) - self.assertNotIn('other_arg', doc) - self.assertNotIn('versionadded', doc) - - - def test_parse_numpydoc(self): - # Unit test ensuring that _parse_numpydoc correctly parses docstrings for all - # functions in NumPy's top-level namespace. - section_titles = {'Attributes', 'Examples', 'Notes', - 'Parameters', 'Raises', 'References', - 'Returns', 'See also', 'See Also', 'Warnings', 'Warns'} - headings = [title + '\n' + '-'*len(title) for title in section_titles] - - for name in dir(np): - if name.startswith('_'): - continue - obj = getattr(np, name) - if isinstance(obj, type): - continue - if not callable(obj): - continue - if 'built-in function' in repr(obj): - continue - parsed = _parse_numpydoc(obj.__doc__) - - # Check that no docstring is handled gracefully. - if not obj.__doc__: - self.assertEqual(parsed, ParsedDoc(obj.__doc__)) - continue - - # Check that no unexpected section names are found. - extra_keys = parsed.sections.keys() - section_titles - if extra_keys: - raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}") - - # Check that every docstring has a summary. - if not parsed.summary: - raise ValueError(f"No summary found for np.{name}") - - # Check that no expected headings are missed. - for heading in headings: - assert heading not in parsed.front_matter - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())