mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove obsolete implements() decorator & fix tests
This commit is contained in:
parent
e82d5a973b
commit
14030801a5
@ -13,11 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
from typing import Any
|
||||
|
||||
import warnings
|
||||
|
||||
@ -34,173 +32,6 @@ import numpy as np
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_parameter_break = re.compile("\n(?=[A-Za-z_])")
|
||||
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
|
||||
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE)
|
||||
_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE)
|
||||
_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`')
|
||||
|
||||
class ParsedDoc(NamedTuple):
|
||||
"""
|
||||
docstr: full docstring
|
||||
signature: signature from docstring.
|
||||
summary: summary from docstring.
|
||||
front_matter: front matter before sections.
|
||||
sections: dictionary of section titles to section content.
|
||||
"""
|
||||
docstr: str | None
|
||||
signature: str = ""
|
||||
summary: str = ""
|
||||
front_matter: str = ""
|
||||
sections: dict[str, str] = {}
|
||||
|
||||
|
||||
def _parse_numpydoc(docstr: str | None) -> ParsedDoc:
|
||||
"""Parse a standard numpy-style docstring.
|
||||
|
||||
Args:
|
||||
docstr: the raw docstring from a function
|
||||
Returns:
|
||||
ParsedDoc: parsed version of the docstring
|
||||
"""
|
||||
if docstr is None or not docstr.strip():
|
||||
return ParsedDoc(docstr)
|
||||
|
||||
# Remove any :doc: directives in the docstring to avoid sphinx errors
|
||||
docstr = _docreference.sub(
|
||||
lambda match: f"{match.groups()[0]}", docstr)
|
||||
|
||||
signature, body = "", docstr
|
||||
match = _numpy_signature_re.match(body)
|
||||
if match:
|
||||
signature = match.group()
|
||||
body = docstr[match.end():]
|
||||
|
||||
firstline, _, body = body.partition('\n')
|
||||
body = textwrap.dedent(body.lstrip('\n'))
|
||||
|
||||
match = _numpy_signature_re.match(body)
|
||||
if match:
|
||||
signature = match.group()
|
||||
body = body[match.end():]
|
||||
|
||||
summary = firstline
|
||||
if not summary:
|
||||
summary, _, body = body.lstrip('\n').partition('\n')
|
||||
body = textwrap.dedent(body.lstrip('\n'))
|
||||
|
||||
front_matter = ""
|
||||
body = "\n" + body
|
||||
section_list = _section_break.split(body)
|
||||
if not _section_break.match(section_list[0]):
|
||||
front_matter, *section_list = section_list
|
||||
sections = {section.split('\n', 1)[0]: section for section in section_list}
|
||||
|
||||
return ParsedDoc(docstr=docstr, signature=signature, summary=summary,
|
||||
front_matter=front_matter, sections=sections)
|
||||
|
||||
|
||||
def _parse_parameters(body: str) -> dict[str, str]:
|
||||
"""Parse the Parameters section of a docstring."""
|
||||
title, underline, content = body.split('\n', 2)
|
||||
assert title == 'Parameters'
|
||||
assert underline and not underline.strip('-')
|
||||
parameters = _parameter_break.split(content)
|
||||
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}
|
||||
|
||||
|
||||
def implements(
|
||||
original_fun: Callable[..., Any] | None,
|
||||
update_doc: bool = True,
|
||||
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
|
||||
module: str | None = None,
|
||||
) -> Callable[[_T], _T]:
|
||||
"""Decorator for JAX functions which implement a specified NumPy function.
|
||||
|
||||
This mainly contains logic to copy and modify the docstring of the original
|
||||
function. In particular, if `update_doc` is True, parameters listed in the
|
||||
original function that are not supported by the decorated function will
|
||||
be removed from the docstring. For this reason, it is important that parameter
|
||||
names match those in the original numpy function.
|
||||
|
||||
Args:
|
||||
original_fun: The original function being implemented
|
||||
update_doc: whether to transform the numpy docstring to remove references of
|
||||
parameters that are supported by the numpy version but not the JAX version.
|
||||
If False, include the numpy docstring verbatim.
|
||||
sections: a list of sections to include in the docstring. The default is
|
||||
["Parameters", "Returns", "References"]
|
||||
module: an optional string specifying the module from which the original function
|
||||
is imported. This is useful for objects such as ufuncs, where the module cannot
|
||||
be determined from the original function itself.
|
||||
"""
|
||||
def decorator(wrapped_fun):
|
||||
wrapped_fun.__np_wrapped__ = original_fun
|
||||
# Allows this pattern: @implements(getattr(np, 'new_function', None))
|
||||
if original_fun is None:
|
||||
return wrapped_fun
|
||||
docstr = getattr(original_fun, "__doc__", None)
|
||||
name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun)))
|
||||
try:
|
||||
mod = module or original_fun.__module__
|
||||
except AttributeError:
|
||||
if config.enable_checks.value:
|
||||
raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().")
|
||||
else:
|
||||
name = f"{mod}.{name}"
|
||||
if docstr:
|
||||
try:
|
||||
parsed = _parse_numpydoc(docstr)
|
||||
|
||||
if update_doc and 'Parameters' in parsed.sections:
|
||||
code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None)
|
||||
# Remove unrecognized parameter descriptions.
|
||||
parameters = _parse_parameters(parsed.sections['Parameters'])
|
||||
parameters = {p: desc for p, desc in parameters.items()
|
||||
if (code is None or p in code.co_varnames)}
|
||||
if parameters:
|
||||
parsed.sections['Parameters'] = (
|
||||
"Parameters\n"
|
||||
"----------\n" +
|
||||
"\n".join(_versionadded.split(desc)[0].rstrip()
|
||||
for p, desc in parameters.items())
|
||||
)
|
||||
else:
|
||||
del parsed.sections['Parameters']
|
||||
|
||||
docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
|
||||
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
|
||||
docstr += "\n*Original docstring below.*\n"
|
||||
|
||||
# We remove signatures from the docstrings, because they redundant at best and
|
||||
# misleading at worst: e.g. JAX wrappers don't implement all ufunc keyword arguments.
|
||||
# if parsed.signature:
|
||||
# docstr += "\n" + parsed.signature.strip() + "\n"
|
||||
|
||||
if parsed.front_matter:
|
||||
docstr += "\n" + parsed.front_matter.strip() + "\n"
|
||||
kept_sections = (content.strip() for section, content in parsed.sections.items()
|
||||
if section in sections)
|
||||
if kept_sections:
|
||||
docstr += "\n" + "\n\n".join(kept_sections) + "\n"
|
||||
except:
|
||||
if config.enable_checks.value:
|
||||
raise
|
||||
docstr = original_fun.__doc__
|
||||
|
||||
wrapped_fun.__doc__ = docstr
|
||||
for attr in ['__name__', '__qualname__']:
|
||||
try:
|
||||
value = getattr(original_fun, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
setattr(wrapped_fun, attr, value)
|
||||
return wrapped_fun
|
||||
return decorator
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
|
@ -51,7 +51,6 @@ from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements
|
||||
from jax._src.util import safe_zip, NumpyComplexWarning
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -6186,9 +6185,114 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
|
||||
def testWrappedSignaturesMatch(self):
|
||||
"""Test that jax.numpy function signatures match numpy."""
|
||||
jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)}
|
||||
func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items()
|
||||
if getattr(fun, '__np_wrapped__', None) is not None}
|
||||
# NumPy functions explicitly not implemented in JAX:
|
||||
skip = {'array2string',
|
||||
'asanyarray',
|
||||
'asarray_chkfinite',
|
||||
'ascontiguousarray',
|
||||
'asfortranarray',
|
||||
'asmatrix',
|
||||
'base_repr',
|
||||
'binary_repr',
|
||||
'bmat',
|
||||
'broadcast',
|
||||
'busday_count',
|
||||
'busday_offset',
|
||||
'busdaycalendar',
|
||||
'common_type',
|
||||
'copyto',
|
||||
'datetime_as_string',
|
||||
'datetime_data',
|
||||
'errstate',
|
||||
'flatiter',
|
||||
'format_float_positional',
|
||||
'format_float_scientific',
|
||||
'fromregex',
|
||||
'genfromtxt',
|
||||
'get_include',
|
||||
'getbufsize',
|
||||
'geterr',
|
||||
'geterrcall',
|
||||
'in1d',
|
||||
'info',
|
||||
'is_busday',
|
||||
'isfortran',
|
||||
'isnat',
|
||||
'loadtxt',
|
||||
'matrix',
|
||||
'may_share_memory',
|
||||
'memmap',
|
||||
'min_scalar_type',
|
||||
'mintypecode',
|
||||
'ndenumerate',
|
||||
'ndindex',
|
||||
'nditer',
|
||||
'nested_iters',
|
||||
'poly1d',
|
||||
'put_along_axis',
|
||||
'putmask',
|
||||
'real_if_close',
|
||||
'recarray',
|
||||
'record',
|
||||
'require',
|
||||
'row_stack',
|
||||
'savetxt',
|
||||
'savez_compressed',
|
||||
'setbufsize',
|
||||
'seterr',
|
||||
'seterrcall',
|
||||
'shares_memory',
|
||||
'show_config',
|
||||
'show_runtime',
|
||||
'test',
|
||||
'trapz',
|
||||
'typename'}
|
||||
|
||||
# symbols removed in NumPy 2.0
|
||||
skip |= {'add_docstring',
|
||||
'add_newdoc',
|
||||
'add_newdoc_ufunc',
|
||||
'alltrue',
|
||||
'asfarray',
|
||||
'byte_bounds',
|
||||
'compare_chararrays',
|
||||
'cumproduct',
|
||||
'deprecate',
|
||||
'deprecate_with_doc',
|
||||
'disp',
|
||||
'fastCopyAndTranspose',
|
||||
'find_common_type',
|
||||
'get_array_wrap',
|
||||
'geterrobj',
|
||||
'issctype',
|
||||
'issubclass_',
|
||||
'issubsctype',
|
||||
'lookfor',
|
||||
'mat',
|
||||
'maximum_sctype',
|
||||
'msort',
|
||||
'obj2sctype',
|
||||
'product',
|
||||
'recfromcsv',
|
||||
'recfromtxt',
|
||||
'round_',
|
||||
'safe_eval',
|
||||
'sctype2char',
|
||||
'set_numeric_ops',
|
||||
'set_string_function',
|
||||
'seterrobj',
|
||||
'sometrue',
|
||||
'source',
|
||||
'who'}
|
||||
|
||||
self.assertEmpty(skip.intersection(dir(jnp)))
|
||||
|
||||
names = (name for name in dir(np) if not (name.startswith('_') or name in skip))
|
||||
names = (name for name in names if callable(getattr(np, name)))
|
||||
names = {name for name in names if not isinstance(getattr(np, name), type)}
|
||||
self.assertEmpty(names.difference(dir(jnp)))
|
||||
|
||||
self.assertNotEmpty(names)
|
||||
|
||||
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
|
||||
unsupported_params = {
|
||||
@ -6199,6 +6303,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'copy': ['subok'],
|
||||
'corrcoef': ['ddof', 'bias', 'dtype'],
|
||||
'cov': ['dtype'],
|
||||
'cumulative_prod': ['out'],
|
||||
'cumulative_sum': ['out'],
|
||||
'empty_like': ['subok', 'order'],
|
||||
'einsum': ['kwargs'],
|
||||
@ -6210,9 +6315,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'full': ['order', 'like'],
|
||||
'full_like': ['subok', 'order'],
|
||||
'fromfunction': ['like'],
|
||||
'histogram': ['normed'],
|
||||
'histogram2d': ['normed'],
|
||||
'histogramdd': ['normed'],
|
||||
'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'],
|
||||
'nanpercentile': ['weights'],
|
||||
'nanquantile': ['weights'],
|
||||
'nanstd': ['correction', 'mean'],
|
||||
@ -6222,7 +6325,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'partition': ['kind', 'order'],
|
||||
'percentile': ['weights'],
|
||||
'quantile': ['weights'],
|
||||
'reshape': ['shape', 'copy'],
|
||||
'row_stack': ['casting'],
|
||||
'stack': ['casting'],
|
||||
'std': ['mean'],
|
||||
@ -6233,18 +6335,19 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
extra_params = {
|
||||
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
|
||||
# standard
|
||||
'clip': ['x', 'max', 'min'],
|
||||
'compress': ['size', 'fill_value'],
|
||||
'einsum': ['subscripts', 'precision'],
|
||||
'einsum_path': ['subscripts'],
|
||||
'load': ['args', 'kwargs'],
|
||||
'take_along_axis': ['mode', 'fill_value'],
|
||||
'fill_diagonal': ['inplace'],
|
||||
}
|
||||
|
||||
mismatches = {}
|
||||
|
||||
for name, (jnp_fun, np_fun) in func_pairs.items():
|
||||
for name in names:
|
||||
jnp_fun = getattr(jnp, name)
|
||||
np_fun = getattr(np, name)
|
||||
if name in ['histogram', 'histogram2d', 'histogramdd']:
|
||||
# numpy 1.24 re-orders the density and weights arguments.
|
||||
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
|
||||
@ -6258,12 +6361,15 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
# TODO(dfm): After our deprecation period for the clip arguments ends
|
||||
# it should be possible to reintroduce the check.
|
||||
continue
|
||||
# Note: can't use inspect.getfullargspec due to numpy issue
|
||||
if name == "reshape":
|
||||
# Similar issue to clip: we'd need logic specific to the NumPy version
|
||||
# because of the change in argument name from `newshape` to `shape`.
|
||||
continue
|
||||
# Note: can't use inspect.getfullargspec for some functions due to numpy issue
|
||||
# https://github.com/numpy/numpy/issues/12225
|
||||
try:
|
||||
np_params = inspect.signature(np_fun).parameters
|
||||
except ValueError:
|
||||
# Some functions cannot be inspected
|
||||
continue
|
||||
jnp_params = inspect.signature(jnp_fun).parameters
|
||||
extra = set(extra_params.get(name, []))
|
||||
@ -6350,8 +6456,6 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
class NumpyDocTests(jtu.JaxTestCase):
|
||||
|
||||
def test_lax_numpy_docstrings(self):
|
||||
# Test that docstring wrapping & transformation didn't fail.
|
||||
|
||||
unimplemented = ['fromfile', 'fromiter']
|
||||
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
|
||||
'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift',
|
||||
@ -6371,15 +6475,6 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
elif hasattr(np, name) and obj is getattr(np, name):
|
||||
# Some APIs are imported directly from NumPy; we don't check these.
|
||||
pass
|
||||
elif hasattr(obj, '__np_wrapped__'):
|
||||
# Functions decorated with @implements(...) should have __np_wrapped__
|
||||
wrapped_fun = obj.__np_wrapped__
|
||||
if wrapped_fun is not None:
|
||||
# If the wrapped function has a docstring, obj should too
|
||||
if wrapped_fun.__doc__ and not obj.__doc__:
|
||||
raise Exception(f"jnp.{name} does not contain wrapped docstring.")
|
||||
if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__:
|
||||
raise Exception(f"jnp.{name} does not have a wrapped docstring.")
|
||||
elif name in aliases:
|
||||
assert "Alias of" in obj.__doc__
|
||||
elif name not in skip_args_check:
|
||||
@ -6391,84 +6486,6 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
if name not in ["frompyfunc", "isdtype", "promote_types"]:
|
||||
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
|
||||
def test_wrapped_function_parameters(self, jit):
|
||||
def orig(x):
|
||||
"""Example Docstring
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : array_like
|
||||
Input Data
|
||||
|
||||
.. versionadded:: 1.8.0
|
||||
out : array_like, optional
|
||||
Output to overwrite
|
||||
other_arg : Any
|
||||
not used
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : input
|
||||
"""
|
||||
return x
|
||||
|
||||
def wrapped(x, out=None):
|
||||
return x
|
||||
|
||||
if jit:
|
||||
wrapped = jax.jit(wrapped)
|
||||
|
||||
wrapped = implements(orig)(wrapped)
|
||||
doc = wrapped.__doc__
|
||||
|
||||
self.assertStartsWith(doc, "Example Docstring")
|
||||
self.assertIn("Original docstring below", doc)
|
||||
self.assertIn("Parameters", doc)
|
||||
self.assertIn("Returns", doc)
|
||||
self.assertNotIn('other_arg', doc)
|
||||
self.assertNotIn('versionadded', doc)
|
||||
|
||||
|
||||
def test_parse_numpydoc(self):
|
||||
# Unit test ensuring that _parse_numpydoc correctly parses docstrings for all
|
||||
# functions in NumPy's top-level namespace.
|
||||
section_titles = {'Attributes', 'Examples', 'Notes',
|
||||
'Parameters', 'Raises', 'References',
|
||||
'Returns', 'See also', 'See Also', 'Warnings', 'Warns'}
|
||||
headings = [title + '\n' + '-'*len(title) for title in section_titles]
|
||||
|
||||
for name in dir(np):
|
||||
if name.startswith('_'):
|
||||
continue
|
||||
obj = getattr(np, name)
|
||||
if isinstance(obj, type):
|
||||
continue
|
||||
if not callable(obj):
|
||||
continue
|
||||
if 'built-in function' in repr(obj):
|
||||
continue
|
||||
parsed = _parse_numpydoc(obj.__doc__)
|
||||
|
||||
# Check that no docstring is handled gracefully.
|
||||
if not obj.__doc__:
|
||||
self.assertEqual(parsed, ParsedDoc(obj.__doc__))
|
||||
continue
|
||||
|
||||
# Check that no unexpected section names are found.
|
||||
extra_keys = parsed.sections.keys() - section_titles
|
||||
if extra_keys:
|
||||
raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}")
|
||||
|
||||
# Check that every docstring has a summary.
|
||||
if not parsed.summary:
|
||||
raise ValueError(f"No summary found for np.{name}")
|
||||
|
||||
# Check that no expected headings are missed.
|
||||
for heading in headings:
|
||||
assert heading not in parsed.front_matter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user