1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Add Sphinx-generated reference documentation for JAX.

This commit is contained in:
Peter Hawkins 2019-01-15 20:14:19 -05:00
parent 4792b9bed3
commit 86d8915c3d
14 changed files with 546 additions and 10 deletions

191
docs/conf.py Normal file

@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
# -- Project information -----------------------------------------------------
project = 'JAX'
copyright = '2019, Google LLC. NumPy and SciPy documentation are copyright the respective authors.'
author = 'The JAX authors'
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
]
intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
}
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
autosummary_generate = True
napolean_use_rtype = False
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'JAXdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'JAX.tex', 'JAX Documentation',
'The JAX authors', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'jax', 'JAX Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'JAX', 'JAX Documentation',
author, 'JAX', 'One line description of project.',
'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------

22
docs/index.rst Normal file

@ -0,0 +1,22 @@
JAX reference documentation
===============================
Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more
For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
.. toctree::
:maxdepth: 3
:caption: Contents:
jax
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

@ -0,0 +1,7 @@
jax.experimental.minmax module
==============================
.. automodule:: jax.experimental.minmax
:members:
:undoc-members:
:show-inheritance:

10
docs/jax.experimental.rst Normal file

@ -0,0 +1,10 @@
jax.experimental package
========================
.. toctree::
:maxdepth: 1
jax.experimental.minmax
jax.experimental.stax
.. automodule:: jax.experimental

@ -0,0 +1,7 @@
jax.experimental.stax module
============================
.. automodule:: jax.experimental.stax
:members:
:undoc-members:
:show-inheritance:

6
docs/jax.lax.rst Normal file

@ -0,0 +1,6 @@
jax.lax package
================
.. automodule:: jax.lax
:members:
:undoc-members:

182
docs/jax.numpy.rst Normal file

@ -0,0 +1,182 @@
jax.numpy package
=================
.. currentmodule:: jax.numpy
.. automodule:: jax.numpy
.. autosummary::
:toctree: _autosummary
abs
absolute
add
all
allclose
alltrue
amax
amin
angle
any
append
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argsort
around
array
asarray
atleast_1d
atleast_2d
bitwise_and
bitwise_not
bitwise_or
bitwise_xor
broadcast_arrays
broadcast_to
ceil
clip
column_stack
concatenate
conj
conjugate
cos
cosh
count_nonzero
diag
diag_indices
diagonal
divide
divmod
dot
einsum
equal
exp
exp2
expand_dims
expm1
eye
fabs
flip
floor
floor_divide
fmod
full
full_like
geomspace
greater
greater_equal
hstack
identity
imag
inner
isclose
isfinite
isinf
isnan
isneginf
isposinf
kron
left_shift
less
less_equal
linspace
log
log10
log1p
log2
logaddexp
logaddexp2
logical_and
logical_not
logical_or
logical_xor
logspace
matmul
max
maximum
mean
meshgrid
min
minimum
mod
moveaxis
multiply
nan_to_num
nanmax
nanmin
nanprod
nansum
negative
not_equal
ones
ones_like
outer
pad
polyval
power
prod
ravel
real
remainder
repeat
reshape
right_shift
rot90
round
row_stack
sign
sin
sinh
sometrue
sort
split
sqrt
square
squeeze
stack
std
subtract
sum
swapaxes
take_along_axis
tan
tanh
tensordot
trace
transpose
tri
tril
triu
true_divide
var
vdot
vstack
where
zeros
zeros_like
jax.numpy.linalg
----------------
.. automodule:: jax.numpy.linalg
.. autosummary::
:toctree: _autosummary
cholesky
det
eigh
inv
qr
slogdet
solve
svd

21
docs/jax.rst Normal file

@ -0,0 +1,21 @@
jax package
===========
Subpackages
-----------
.. toctree::
:maxdepth: 1
jax.numpy
jax.scipy
jax.experimental
jax.lax
Module contents
---------------
.. automodule:: jax
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, make_jaxpr
:undoc-members:
:show-inheritance:

53
docs/jax.scipy.rst Normal file

@ -0,0 +1,53 @@
jax.scipy package
=================
jax.scipy.linalg
-----------------------
.. automodule:: jax.scipy.linalg
:members:
jax.scipy.misc
---------------------
.. automodule:: jax.scipy.misc
:members:
jax.scipy.special
------------------------
.. automodule:: jax.scipy.special
:members:
jax.scipy.stats
-----------------------
jax.scipy.stats.beta
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.beta
:members:
jax.scipy.stats.expon
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.expon
:members:
jax.scipy.stats.gamma
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.gamma
:members:
jax.scipy.stats.laplace
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.laplace
:members:
jax.scipy.stats.norm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.norm
:members:
jax.scipy.stats.uniform
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.uniform
:members:

7
docs/modules.rst Normal file

@ -0,0 +1,7 @@
jax
===
.. toctree::
:maxdepth: 4
jax

@ -227,9 +227,9 @@ def vmap(fun, in_axes=0, out_axes=0):
For example, we can implement a matrix-matrix product using a vector dot
product:
vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0) # ([a,b], [b]) -> [a]
mm = vmap(mv, (None, 1), 1) # ([a,b], [b,c]) -> [a,c]
>>> vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0) # ([a,b], [b]) -> [a]
>>> mm = vmap(mv, (None, 1), 1) # ([a,b], [b,c]) -> [a,c]
(`[a,b]` indicates an array with shape (a,b))
"""

@ -18,6 +18,7 @@ from __future__ import print_function
import collections
import itertools
import re
import string
import warnings
@ -199,15 +200,36 @@ def _promote_args_like(op, *args):
def _constant_like(x, const):
return onp.array(const, dtype=_dtype(x))
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$')
def _wraps(fun):
"""Like functools.wraps but works with numpy.ufuncs."""
def wrap(op):
try:
docstr = """
LAX-backed implementation of {fun}. Original docstring below.
# Numpy doc comments have the form:
# fn(x, y, z) (optional)
#
# A one-line summary
#
# ... everything else ...
# We (a) move the summary to the top, since it is what the Sphinx
# autosummary extension expects, and (b) add a comment below the summary
# to the effect that this is a LAX wrapper of a Numpy function.
sections = fun.__doc__.split("\n\n")
{np_doc}
""".format(fun=fun.__name__, np_doc=fun.__doc__)
signatures = []
summary = None
for i in xrange(len(sections)):
if _numpy_signature_re.match(sections[i]):
signatures.append(sections[i])
else:
summary = sections[i].strip()
break
body = "\n\n".join(signatures + sections[i + 1:])
docstr = (
"{summary}\n\nLAX-backend implementation of :func:`{fun}`. "
"Original docstring below.\n\n{body}".format(
summary=summary, fun=fun.__name__, body=body))
op.__name__ = fun.__name__
op.__doc__ = docstr
finally:
@ -426,7 +448,7 @@ def remainder(x1, x2):
x1, x2 = _promote_args("remainder", x1, x2)
return lax.rem(lax.add(lax.rem(x1, x2), x2), x2)
mod = remainder
fmod = lax.rem
fmod = _wraps(onp.fmod)(lambda x, y: lax.rem(x, y))
@_wraps(onp.sqrt)
@ -1268,6 +1290,7 @@ def tensordot(a, b, axes=2):
raise TypeError(msg)
@_wraps(onp.einsum)
def einsum(*operands):
# using einsum_call=True here is an internal api for opt_einsum
operands, contractions = opt_einsum.contract_path(

5
readthedocs.yml Normal file

@ -0,0 +1,5 @@
build:
image: latest
python:
version: 3.6

@ -21,8 +21,10 @@ setup(
author='JAX team',
author_email='jax-dev@google.com',
packages=find_packages(),
install_requires=['numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py',
'opt_einsum'],
install_requires=[
'jaxlib>=0.1.4', 'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py',
'opt_einsum'
],
url='https://github.com/google/jax',
license='Apache-2.0',
)