mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Add Sphinx-generated reference documentation for JAX.
This commit is contained in:
parent
4792b9bed3
commit
86d8915c3d
191
docs/conf.py
Normal file
191
docs/conf.py
Normal file
@ -0,0 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file does only contain a selection of the most common options. For a
|
||||
# full list see the documentation:
|
||||
# http://www.sphinx-doc.org/en/master/config
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath('..'))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'JAX'
|
||||
copyright = '2019, Google LLC. NumPy and SciPy documentation are copyright the respective authors.'
|
||||
author = 'The JAX authors'
|
||||
|
||||
# The short X.Y version
|
||||
version = ''
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = ''
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
#
|
||||
# needs_sphinx = '1.0'
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.napoleon',
|
||||
'sphinx.ext.viewcode',
|
||||
]
|
||||
|
||||
intersphinx_mapping = {
|
||||
'python': ('https://docs.python.org/3/', None),
|
||||
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
|
||||
'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None),
|
||||
}
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
#
|
||||
# source_suffix = ['.rst', '.md']
|
||||
source_suffix = '.rst'
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
#
|
||||
# This is also used if you do content translation via gettext catalogs.
|
||||
# Usually you set "language" from the command line for these cases.
|
||||
language = None
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = []
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = None
|
||||
|
||||
|
||||
autosummary_generate = True
|
||||
napolean_use_rtype = False
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
#
|
||||
# html_theme_options = {}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
|
||||
# Custom sidebar templates, must be a dictionary that maps document names
|
||||
# to template names.
|
||||
#
|
||||
# The default sidebars (for documents that don't match any pattern) are
|
||||
# defined by theme itself. Builtin themes are using these templates by
|
||||
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
|
||||
# 'searchbox.html']``.
|
||||
#
|
||||
# html_sidebars = {}
|
||||
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'JAXdoc'
|
||||
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#
|
||||
# 'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# 'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# 'figure_align': 'htbp',
|
||||
}
|
||||
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'JAX.tex', 'JAX Documentation',
|
||||
'The JAX authors', 'manual'),
|
||||
]
|
||||
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
(master_doc, 'jax', 'JAX Documentation',
|
||||
[author], 1)
|
||||
]
|
||||
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
|
||||
# Grouping the document tree into Texinfo files. List of tuples
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'JAX', 'JAX Documentation',
|
||||
author, 'JAX', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
||||
|
||||
# -- Options for Epub output -------------------------------------------------
|
||||
|
||||
# Bibliographic Dublin Core info.
|
||||
epub_title = project
|
||||
|
||||
# The unique identifier of the text. This can be a ISBN number
|
||||
# or the project homepage.
|
||||
#
|
||||
# epub_identifier = ''
|
||||
|
||||
# A unique identification for the text.
|
||||
#
|
||||
# epub_uid = ''
|
||||
|
||||
# A list of files that should not be packed into the epub file.
|
||||
epub_exclude_files = ['search.html']
|
||||
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
22
docs/index.rst
Normal file
22
docs/index.rst
Normal file
@ -0,0 +1,22 @@
|
||||
JAX reference documentation
|
||||
===============================
|
||||
|
||||
Composable transformations of Python+NumPy programs: differentiate, vectorize,
|
||||
JIT to GPU/TPU, and more
|
||||
|
||||
For an introduction to JAX, start at the
|
||||
`JAX GitHub page <https://github.com/google/jax>`_.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: Contents:
|
||||
|
||||
jax
|
||||
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
7
docs/jax.experimental.minmax.rst
Normal file
7
docs/jax.experimental.minmax.rst
Normal file
@ -0,0 +1,7 @@
|
||||
jax.experimental.minmax module
|
||||
==============================
|
||||
|
||||
.. automodule:: jax.experimental.minmax
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
10
docs/jax.experimental.rst
Normal file
10
docs/jax.experimental.rst
Normal file
@ -0,0 +1,10 @@
|
||||
jax.experimental package
|
||||
========================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.minmax
|
||||
jax.experimental.stax
|
||||
|
||||
.. automodule:: jax.experimental
|
7
docs/jax.experimental.stax.rst
Normal file
7
docs/jax.experimental.stax.rst
Normal file
@ -0,0 +1,7 @@
|
||||
jax.experimental.stax module
|
||||
============================
|
||||
|
||||
.. automodule:: jax.experimental.stax
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
6
docs/jax.lax.rst
Normal file
6
docs/jax.lax.rst
Normal file
@ -0,0 +1,6 @@
|
||||
jax.lax package
|
||||
================
|
||||
|
||||
.. automodule:: jax.lax
|
||||
:members:
|
||||
:undoc-members:
|
182
docs/jax.numpy.rst
Normal file
182
docs/jax.numpy.rst
Normal file
@ -0,0 +1,182 @@
|
||||
|
||||
jax.numpy package
|
||||
=================
|
||||
|
||||
.. currentmodule:: jax.numpy
|
||||
|
||||
.. automodule:: jax.numpy
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
abs
|
||||
absolute
|
||||
add
|
||||
all
|
||||
allclose
|
||||
alltrue
|
||||
amax
|
||||
amin
|
||||
angle
|
||||
any
|
||||
append
|
||||
arange
|
||||
arccos
|
||||
arccosh
|
||||
arcsin
|
||||
arcsinh
|
||||
arctan
|
||||
arctan2
|
||||
arctanh
|
||||
argmax
|
||||
argmin
|
||||
argsort
|
||||
around
|
||||
array
|
||||
asarray
|
||||
atleast_1d
|
||||
atleast_2d
|
||||
bitwise_and
|
||||
bitwise_not
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
column_stack
|
||||
concatenate
|
||||
conj
|
||||
conjugate
|
||||
cos
|
||||
cosh
|
||||
count_nonzero
|
||||
diag
|
||||
diag_indices
|
||||
diagonal
|
||||
divide
|
||||
divmod
|
||||
dot
|
||||
einsum
|
||||
equal
|
||||
exp
|
||||
exp2
|
||||
expand_dims
|
||||
expm1
|
||||
eye
|
||||
fabs
|
||||
flip
|
||||
floor
|
||||
floor_divide
|
||||
fmod
|
||||
full
|
||||
full_like
|
||||
geomspace
|
||||
greater
|
||||
greater_equal
|
||||
hstack
|
||||
identity
|
||||
imag
|
||||
inner
|
||||
isclose
|
||||
isfinite
|
||||
isinf
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
kron
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
log
|
||||
log10
|
||||
log1p
|
||||
log2
|
||||
logaddexp
|
||||
logaddexp2
|
||||
logical_and
|
||||
logical_not
|
||||
logical_or
|
||||
logical_xor
|
||||
logspace
|
||||
matmul
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
mod
|
||||
moveaxis
|
||||
multiply
|
||||
nan_to_num
|
||||
nanmax
|
||||
nanmin
|
||||
nanprod
|
||||
nansum
|
||||
negative
|
||||
not_equal
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
pad
|
||||
polyval
|
||||
power
|
||||
prod
|
||||
ravel
|
||||
real
|
||||
remainder
|
||||
repeat
|
||||
reshape
|
||||
right_shift
|
||||
rot90
|
||||
round
|
||||
row_stack
|
||||
sign
|
||||
sin
|
||||
sinh
|
||||
sometrue
|
||||
sort
|
||||
split
|
||||
sqrt
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
std
|
||||
subtract
|
||||
sum
|
||||
swapaxes
|
||||
take_along_axis
|
||||
tan
|
||||
tanh
|
||||
tensordot
|
||||
trace
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
true_divide
|
||||
var
|
||||
vdot
|
||||
vstack
|
||||
where
|
||||
zeros
|
||||
zeros_like
|
||||
|
||||
jax.numpy.linalg
|
||||
----------------
|
||||
|
||||
.. automodule:: jax.numpy.linalg
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
cholesky
|
||||
det
|
||||
eigh
|
||||
inv
|
||||
qr
|
||||
slogdet
|
||||
solve
|
||||
svd
|
21
docs/jax.rst
Normal file
21
docs/jax.rst
Normal file
@ -0,0 +1,21 @@
|
||||
jax package
|
||||
===========
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.numpy
|
||||
jax.scipy
|
||||
jax.experimental
|
||||
jax.lax
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, make_jaxpr
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
53
docs/jax.scipy.rst
Normal file
53
docs/jax.scipy.rst
Normal file
@ -0,0 +1,53 @@
|
||||
jax.scipy package
|
||||
=================
|
||||
|
||||
jax.scipy.linalg
|
||||
-----------------------
|
||||
|
||||
.. automodule:: jax.scipy.linalg
|
||||
:members:
|
||||
|
||||
jax.scipy.misc
|
||||
---------------------
|
||||
|
||||
.. automodule:: jax.scipy.misc
|
||||
:members:
|
||||
|
||||
jax.scipy.special
|
||||
------------------------
|
||||
|
||||
.. automodule:: jax.scipy.special
|
||||
:members:
|
||||
|
||||
jax.scipy.stats
|
||||
-----------------------
|
||||
|
||||
jax.scipy.stats.beta
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.beta
|
||||
:members:
|
||||
|
||||
jax.scipy.stats.expon
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.expon
|
||||
:members:
|
||||
|
||||
jax.scipy.stats.gamma
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.gamma
|
||||
:members:
|
||||
|
||||
jax.scipy.stats.laplace
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.laplace
|
||||
:members:
|
||||
|
||||
jax.scipy.stats.norm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.norm
|
||||
:members:
|
||||
|
||||
jax.scipy.stats.uniform
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.uniform
|
||||
:members:
|
7
docs/modules.rst
Normal file
7
docs/modules.rst
Normal file
@ -0,0 +1,7 @@
|
||||
jax
|
||||
===
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
jax
|
@ -227,9 +227,9 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
For example, we can implement a matrix-matrix product using a vector dot
|
||||
product:
|
||||
|
||||
vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
|
||||
mv = vmap(vv, (0, None), 0) # ([a,b], [b]) -> [a]
|
||||
mm = vmap(mv, (None, 1), 1) # ([a,b], [b,c]) -> [a,c]
|
||||
>>> vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
|
||||
>>> mv = vmap(vv, (0, None), 0) # ([a,b], [b]) -> [a]
|
||||
>>> mm = vmap(mv, (None, 1), 1) # ([a,b], [b,c]) -> [a,c]
|
||||
|
||||
(`[a,b]` indicates an array with shape (a,b))
|
||||
"""
|
||||
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import re
|
||||
import string
|
||||
import warnings
|
||||
|
||||
@ -199,15 +200,36 @@ def _promote_args_like(op, *args):
|
||||
def _constant_like(x, const):
|
||||
return onp.array(const, dtype=_dtype(x))
|
||||
|
||||
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$')
|
||||
|
||||
def _wraps(fun):
|
||||
"""Like functools.wraps but works with numpy.ufuncs."""
|
||||
def wrap(op):
|
||||
try:
|
||||
docstr = """
|
||||
LAX-backed implementation of {fun}. Original docstring below.
|
||||
# Numpy doc comments have the form:
|
||||
# fn(x, y, z) (optional)
|
||||
#
|
||||
# A one-line summary
|
||||
#
|
||||
# ... everything else ...
|
||||
# We (a) move the summary to the top, since it is what the Sphinx
|
||||
# autosummary extension expects, and (b) add a comment below the summary
|
||||
# to the effect that this is a LAX wrapper of a Numpy function.
|
||||
sections = fun.__doc__.split("\n\n")
|
||||
|
||||
{np_doc}
|
||||
""".format(fun=fun.__name__, np_doc=fun.__doc__)
|
||||
signatures = []
|
||||
summary = None
|
||||
for i in xrange(len(sections)):
|
||||
if _numpy_signature_re.match(sections[i]):
|
||||
signatures.append(sections[i])
|
||||
else:
|
||||
summary = sections[i].strip()
|
||||
break
|
||||
body = "\n\n".join(signatures + sections[i + 1:])
|
||||
docstr = (
|
||||
"{summary}\n\nLAX-backend implementation of :func:`{fun}`. "
|
||||
"Original docstring below.\n\n{body}".format(
|
||||
summary=summary, fun=fun.__name__, body=body))
|
||||
op.__name__ = fun.__name__
|
||||
op.__doc__ = docstr
|
||||
finally:
|
||||
@ -426,7 +448,7 @@ def remainder(x1, x2):
|
||||
x1, x2 = _promote_args("remainder", x1, x2)
|
||||
return lax.rem(lax.add(lax.rem(x1, x2), x2), x2)
|
||||
mod = remainder
|
||||
fmod = lax.rem
|
||||
fmod = _wraps(onp.fmod)(lambda x, y: lax.rem(x, y))
|
||||
|
||||
|
||||
@_wraps(onp.sqrt)
|
||||
@ -1268,6 +1290,7 @@ def tensordot(a, b, axes=2):
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
@_wraps(onp.einsum)
|
||||
def einsum(*operands):
|
||||
# using einsum_call=True here is an internal api for opt_einsum
|
||||
operands, contractions = opt_einsum.contract_path(
|
||||
|
5
readthedocs.yml
Normal file
5
readthedocs.yml
Normal file
@ -0,0 +1,5 @@
|
||||
build:
|
||||
image: latest
|
||||
|
||||
python:
|
||||
version: 3.6
|
6
setup.py
6
setup.py
@ -21,8 +21,10 @@ setup(
|
||||
author='JAX team',
|
||||
author_email='jax-dev@google.com',
|
||||
packages=find_packages(),
|
||||
install_requires=['numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py',
|
||||
'opt_einsum'],
|
||||
install_requires=[
|
||||
'jaxlib>=0.1.4', 'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py',
|
||||
'opt_einsum'
|
||||
],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user