# Mirror of https://github.com/ROCm/jax.git
# Synced 2025-04-24 19:16:06 +00:00
# 98 lines, 3.6 KiB, Python
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import re
|
|
|
|
|
|
def update_numpydoc(docstr, fun, op):
  """Removes numpy-only parameters from a numpy docstring.

  Drops every entry of the first numpydoc ``Parameters`` section whose
  (first) name is not a parameter of the JAX implementation ``op``, together
  with its indented description, so the docstring does not advertise
  arguments that JAX does not accept. Later sections (``Other Parameters``,
  ``Returns``, ...) are left untouched.

  Args:
    docstr: numpy docstring to transform.
    fun: the numpy function being wrapped. Not used; kept for backward
      compatibility with existing callers.
    op: the JAX implementation. The parameter names of ``op.__code__``
      decide which entries are kept.

  Returns:
    The transformed docstring. ``docstr`` is returned unchanged if ``op`` has
    no ``__code__``, and with only the leading-indent removal applied if no
    ``Parameters`` header (followed by a dashed underline) is found.
  """
  import inspect  # local import: only needed for the CO_* code flags below.

  if not hasattr(op, '__code__'):
    return docstr

  # Some numpy docstrings indent every line by 4 extra spaces. If this one
  # does, remove that prefix from each line (only as a prefix, so runs of
  # spaces inside a line are preserved).
  extra_indent = '    '
  if docstr.startswith(extra_indent):
    docstr = '\n'.join(
        line[len(extra_indent):] if line.startswith(extra_indent) else line
        for line in docstr.split('\n'))

  # Locate "Parameters" plus its dashed underline. The header may still be
  # indented (e.g. a column-0 signature line precedes an indented body), so
  # entries are recognised relative to the header's own indentation.
  header = re.search(r'^([ \t]*)Parameters[ \t]*\n\1-{3,}[ \t]*\n',
                     docstr, re.MULTILINE)
  if header is None:
    return docstr
  indent = header.group(1)
  start = header.end()

  # The section runs until the next section header (a title line at the same
  # indentation followed by a dashed underline, usually "Returns") or to the
  # end of the docstring.
  next_header = re.search(
      '^' + re.escape(indent) + r'\S[^\n]*\n' + re.escape(indent) +
      r'-{3,}[ \t]*$', docstr[start:], re.MULTILINE)
  end = start + next_header.start() if next_header else len(docstr)

  section = docstr[start:end]
  content = section.rstrip()
  trailer = section[len(content):]  # blank lines before the next header

  # Only genuine parameters count: co_varnames also lists ordinary locals,
  # which follow the positional, keyword-only, *args and **kwargs names.
  code = op.__code__
  num_params = code.co_argcount + code.co_kwonlyargcount
  num_params += bool(code.co_flags & inspect.CO_VARARGS)
  num_params += bool(code.co_flags & inspect.CO_VARKEYWORDS)
  supported = set(code.co_varnames[:num_params])

  kept = []
  keep_entry = True  # text before the first entry header is always kept
  for line in content.split('\n'):
    rest = line[len(indent):] if line.startswith(indent) else ''
    if rest and not rest[0].isspace():
      # An entry header such as "x1, x2 : array_like"; the first listed name
      # decides whether the entry and its indented description are kept.
      name = rest.split(' : ', 1)[0].split(', ')[0].strip()
      keep_entry = name in supported
    if keep_entry:
      kept.append(line)

  return docstr[:start] + '\n'.join(kept).rstrip() + trailer + docstr[end:]
|
|
|
|
# Matches a docstring section that is a call signature such as
# "add(x1, x2, /, out=None)", optionally preceded by an assignment target
# made of names, dots, commas and spaces (e.g. "a, b =").
_numpy_signature_re = re.compile(
    r'^([\w., ]+=)?'  # optional assignment target ending in "="
    r'\s*[\w\.]+\('   # (possibly dotted) callable name and "("
    r'[\w\W]*\)$'     # any text, newlines included, closing with ")"
)
|
|
|
|
|
|
def _wraps(fun, update_doc=True, lax_description=""):
  """Like functools.wraps but works with numpy.ufuncs.

  It is important that when wrapping numpy functions the parameter names in
  the original function and in the JAX version are the same.

  Parameters:
    fun: The function being wrapped.
    update_doc: whether to transform the numpy docstring to remove references
      of parameters that are supported by the numpy version but not the JAX
      version. If False, include the numpy docstring verbatim.
    lax_description: extra text placed after the "LAX-backend implementation
      of ..." line; a newline is appended to it when it is non-empty.

  Returns:
    A decorator. Applied to a JAX function ``op``, it sets ``op.__name__`` to
    ``fun.__name__``, sets ``op.__doc__`` to a docstring built from
    ``fun.__doc__``, stores ``fun`` as ``op.__np_wrapped__`` and returns
    ``op``. If ``fun`` has no docstring, ``op`` is returned unmodified; if an
    ``Exception`` is raised while building the docstring, ``op`` is returned
    without the new name, docstring or ``__np_wrapped__``.
  """
  def wrap(op):
    numpy_doc = getattr(fun, '__doc__', None)
    if numpy_doc is None:
      return op
    try:
      # Numpy doc comments have the form:
      #   fn(x, y, z)          (optional)
      #
      #   A one-line summary
      #
      #   ... everything else ...
      # We (a) move the summary to the top, since it is what the Sphinx
      # autosummary extension expects, and (b) add a comment below the summary
      # to the effect that this is a LAX wrapper of a Numpy function.
      sections = numpy_doc.split("\n\n")

      # Leading sections that look like call signatures stay in the body; the
      # first section that is not a signature is the summary.
      num_signatures = 0
      while (num_signatures < len(sections)
             and _numpy_signature_re.match(sections[num_signatures])):
        num_signatures += 1
      # A docstring made only of signatures has no summary; use "" instead of
      # formatting the literal "None" into the docstring.
      summary = (sections[num_signatures].strip()
                 if num_signatures < len(sections) else "")
      body = "\n\n".join(sections[:num_signatures]
                         + sections[num_signatures + 1:])

      if update_doc:
        body = update_numpydoc(body, fun, op)
      desc = lax_description + "\n" if lax_description else ""
      docstr = (
          "{summary}\n\nLAX-backend implementation of :func:`{fun}`.\n"
          "{lax_description}Original docstring below.\n\n{body}"
          .format(summary=summary, lax_description=desc,
                  fun=fun.__name__, body=body))

      op.__name__ = fun.__name__
      op.__doc__ = docstr
      op.__np_wrapped__ = fun
    except Exception:  # pylint: disable=broad-except
      # Docstring rewriting is best-effort: an unexpected numpy docstring must
      # not prevent the JAX function from being defined. Catching Exception
      # (instead of returning from a `finally` clause) still lets
      # KeyboardInterrupt and SystemExit propagate.
      pass
    return op
  return wrap
|