rocm_jax/jax/numpy/_util.py
2020-07-01 16:52:41 -07:00

98 lines
3.6 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def update_numpydoc(docstr, fun, op):
    """Remove parameters from a numpy docstring that the JAX version lacks.

    Rewrites the "Parameters" section of a numpydoc-formatted docstring so it
    only lists entries whose (first) parameter name appears in
    ``op.__code__.co_varnames``.

    Args:
      docstr: the numpy docstring text to rewrite.
      fun: the numpy function being wrapped. Not used here; kept so existing
        callers keep working.
      op: the JAX implementation whose argument names decide which documented
        parameters are kept.

    Returns:
      The rewritten docstring. ``docstr`` is returned unchanged when ``op`` has
      no ``__code__`` (e.g. builtins). When the "Parameters" heading, its dashed
      underline, or a later "Returns" heading cannot be found, the (dedented)
      docstring is returned without further edits instead of being mangled.
    """
    if not hasattr(op, '__code__'):
        return docstr

    # Some numpy docstrings carry an extra four-space indent on every line.
    # Strip it -- only from the start of each line -- so that headings and
    # parameter names sit at column 0 for the parsing below.
    indent = '    '
    if docstr.startswith(indent):
        docstr = '\n'.join(
            line[len(indent):] if line.startswith(indent) else line
            for line in docstr.split('\n'))

    begin_idx = docstr.find("Parameters")
    if begin_idx == -1:
        return docstr
    underline_end = docstr.find("--\n", begin_idx)
    if underline_end == -1:
        return docstr
    # begin_idx now points at the newline that terminates the dashed underline.
    begin_idx = underline_end + 2
    end_idx = docstr.find("Returns", begin_idx)
    if end_idx == -1:
        return docstr

    supported = op.__code__.co_varnames
    # Fold every indented description line into the entry line above it, so
    # each element of `entries` is one whole parameter entry. Only the first
    # space of the indent is consumed, so the round trip below is exact.
    entries = docstr[begin_idx:end_idx].replace('\n ', '@@').split('\n')
    kept = []
    for entry in entries:
        # An entry head looks like "name : type", "name1, name2 : type" or a
        # bare "name"; the first listed name decides whether it is kept.
        head = entry.split('@@', 1)[0]
        name = head.partition(' : ')[0].split(", ")[0]
        if entry and name in supported:
            kept.append(entry)
    parameters = '\n'.join(kept).replace('@@', '\n ')
    # Re-attach the two characters before "Returns" (normally the blank line
    # separating the sections) together with the rest of the docstring.
    return docstr[:begin_idx + 1] + parameters + docstr[end_idx - 2:]
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*\)$')
def _wraps(fun, update_doc=True, lax_description=""):
"""Like functools.wraps but works with numpy.ufuncs.
It is important that when wrapping numpy functions the parameters names
in the original function and in the JAX version are the same
Parameters:
fun: The function being wrapped
update_doc: whether to transform the numpy docstring to remove references of
parameters that are supported by the numpy version but not the JAX version.
If False, include the numpy docstring verbatim.
"""
def wrap(op):
if not hasattr(fun, '__doc__') or fun.__doc__ is None:
return op
try:
# Numpy doc comments have the form:
# fn(x, y, z) (optional)
#
# A one-line summary
#
# ... everything else ...
# We (a) move the summary to the top, since it is what the Sphinx
# autosummary extension expects, and (b) add a comment below the summary
# to the effect that this is a LAX wrapper of a Numpy function.
sections = fun.__doc__.split("\n\n")
signatures = []
summary = None
for i in range(len(sections)):
if _numpy_signature_re.match(sections[i]):
signatures.append(sections[i])
else:
summary = sections[i].strip()
break
body = "\n\n".join(signatures + sections[i + 1:])
if update_doc:
body = update_numpydoc(body, fun, op)
desc = lax_description + "\n" if lax_description else ""
docstr = (
"{summary}\n\nLAX-backend implementation of :func:`{fun}`.\n"
"{lax_description}Original docstring below.\n\n{body}"
.format(summary=summary, lax_description=desc,
fun=fun.__name__, body=body))
op.__name__ = fun.__name__
op.__doc__ = docstr
op.__np_wrapped__ = fun
finally:
return op
return wrap