mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

added documentation

This commit is contained in:
parent 54f28ed5bc
commit 55a06cbd5b
@@ -6,5 +6,6 @@ jax.experimental package

    jax.experimental.optimizers
    jax.experimental.stax
    jax.experimental.vectorize

.. automodule:: jax.experimental
docs/jax.experimental.vectorize.rst (Normal file, 6 lines)

@@ -0,0 +1,6 @@
jax.experimental.vectorize module
=================================

.. automodule:: jax.experimental.vectorize
    :members:
    :show-inheritance:
@@ -1,3 +1,98 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extending JAX's vmap to work like NumPY's gufuncs.
|
||||
|
||||
From the `example notebook <https://nbviewer.jupyter.org/github/google/jax/blob/master/notebooks/gufuncs.ipynb>`_ by `Stephan Hoyer <https://github.com/shoyer>`_.
|
||||
|
||||
What is a gufunc?
|
||||
=================
|
||||

`Generalized universal functions
<https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html>`_
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
handle non-scalar operations. When a gufunc is applied to arrays, there are:

* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.

A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
associated with each gufunc controls how this happens by indicating how core
dimensions are mapped between inputs and outputs. The syntax is easiest to
understand by looking at a few examples:
* Addition: `(),()->()`
* 1D inner product: `(i),(i)->()`
* 1D sum: `(i)->()`
* Matrix multiplication: `(m,n),(n,k)->(m,k)`

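These signatures can be tried out directly with NumPy's own ``np.vectorize``, which accepts the same gufunc-style ``signature`` argument (a sketch for illustrating semantics; ``np.vectorize`` loops in Python, so it says nothing about performance):

```python
import numpy as np

# A 1D inner product, signature (i),(i)->(): the core dimension i is consumed.
inner = np.vectorize(np.dot, signature='(i),(i)->()')

a = np.ones((3, 4))  # one broadcast dimension (3) plus core dimension i=4
b = np.ones(4)       # no broadcast dimensions; broadcast against a
print(inner(a, b))   # [4. 4. 4.]
```

The broadcast dimension of ``a`` is looped over automatically; only the length-4 core dimension is handed to ``np.dot``.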
Why write gufuncs?
==================

From a user perspective, gufuncs are nice because they're guaranteed to
vectorize in a consistent and general fashion. For example, by default gufuncs
use the last dimensions of arrays as core dimensions, but you can control that
explicitly with the ``axis`` or ``axes`` arguments.

From a developer perspective, gufuncs are nice because they simplify your work:
you only need to think about the core logic of your function, not how it
handles arbitrary dimensional input. You can just write that down in a simple,
declarative way.

JAX makes it easy to write high-level performant code
=====================================================

Unfortunately, writing NumPy gufuncs today is somewhat non-trivial. Your
options today are:

1. Write the inner loops yourself in C.
2. `np.vectorize <https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html>`_ creates something kind of like a gufunc, but it's painfully slow: the outer loop is performed in Python.
3. `numba.guvectorize <https://numba.pydata.org/numba-doc/dev/user/vectorize.html>`_ can work well, if you don't need further code transformations like automatic differentiation.

JAX's ``vmap`` contains all the core functionality we need to write functions that work like gufuncs. JAX gufuncs play nicely with other transformations like ``grad`` and ``jit``.
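To see why ``vmap`` is the right building block, note that it already turns a function written for 1D vectors into one that maps over a leading batch dimension (a minimal sketch using only the public ``jax`` API):

```python
import jax
import jax.numpy as jnp

def inner(a, b):
  # Core logic written for 1D vectors only.
  return jnp.dot(a, b)

# vmap maps over a new leading dimension of the first argument only,
# which is exactly the "broadcast dimension" role in a gufunc.
batched_inner = jax.vmap(inner, in_axes=(0, None))

a = jnp.ones((3, 4))
b = jnp.ones(4)
print(batched_inner(a, b))  # [4. 4. 4.]
```

What a gufunc layer adds on top of ``vmap`` is parsing the signature and applying ``vmap`` repeatedly for arbitrarily many broadcast dimensions.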

A simple example
================

Consider a simple example from data preprocessing, centering an array.

Here's how we might write a vectorized version using NumPy::
  def center(array, axis=-1):
    # array can have any number of dimensions
    bias = np.mean(array, axis=axis)
    debiased = array - np.expand_dims(bias, axis)
    return bias, debiased

And here's how we could write a vectorized version using JAX gufuncs::

  @vectorize('(n)->(),(n)')
  def center(array):
    # array is always a 1D vector
    bias = np.mean(array)
    debiased = array - bias
    return bias, debiased

See the difference?

* Instead of needing to think about broadcasting while writing the entire function, we can write the function assuming the input is always a vector.
* We get the ``axis`` argument automatically, without needing to write it ourselves.
* As a bonus, the decorator makes the function self-documenting: a reader immediately knows that it handles higher dimensional input and output correctly.
"""
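The broadcasting behavior that the ``@vectorize('(n)->(),(n)')`` decorator promises can be checked against NumPy's slow reference implementation, which interprets the same signature (a sketch; ``center_core`` is just the undecorated core logic from above):

```python
import numpy as np

def center_core(array):
  # Core logic on a 1D vector, as in the decorated version above.
  bias = np.mean(array)
  return bias, array - bias

# The same '(n)->(),(n)' signature, interpreted by NumPy.
center = np.vectorize(center_core, signature='(n)->(),(n)')

x = np.arange(6.0).reshape(2, 3)
bias, debiased = center(x)
print(bias)            # [1. 4.]
print(debiased.shape)  # (2, 3)
```

Each row is treated as a core vector: a scalar bias per row, and a debiased output the same shape as the input.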

from jax import grad, jit, vmap
import jax.numpy as jnp
import numpy as np

@@ -12,19 +107,16 @@ _SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)
 def _parse_gufunc_signature(signature):
-  """
-  Parse string signatures for a generalized universal function.
+  """Parse string signatures for a generalized universal function.

-  Arguments
-  ---------
-  signature : string
-    Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
-    for ``np.matmul``.
+  Args:
+    signature : string
+      Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
+      for ``np.matmul``.

-  Returns
-  -------
-  Tuple of input and output core dimensions parsed from the signature, each
-  of the form List[Tuple[str, ...]].
+  Returns:
+    Tuple of input and output core dimensions parsed from the signature, each
+    of the form List[Tuple[str, ...]].
   """
   if not re.match(_SIGNATURE, signature):
     raise ValueError(
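The parsing this function performs can be sketched in a few lines (a hypothetical minimal parser; unlike the real function, it skips the validation against ``_SIGNATURE``):

```python
import re

def parse_gufunc_signature(signature):
  # Split inputs from outputs, then pull the dimension names out of each
  # parenthesized argument.
  in_spec, out_spec = signature.split('->')
  parse = lambda spec: [tuple(re.findall(r'\w+', arg))
                        for arg in re.findall(r'\(([^)]*)\)', spec)]
  return parse(in_spec), parse(out_spec)

print(parse_gufunc_signature('(m,n),(n,p)->(m,p)'))
# ([('m', 'n'), ('n', 'p')], [('m', 'p')])
```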
@@ -36,17 +128,15 @@ def _parse_gufunc_signature(signature):

 def _update_dim_sizes(dim_sizes, arg, core_dims):
-  """
-  Incrementally check and update core dimension sizes for a single argument.
+  """Incrementally check and update core dimension sizes for a single argument.

-  Arguments
-  ---------
-  dim_sizes : Dict[str, int]
-    Sizes of existing core dimensions. Will be updated in-place.
-  arg : ndarray
-    Argument to examine.
-  core_dims : Tuple[str, ...]
-    Core dimensions for this argument.
+  Args:
+    dim_sizes : Dict[str, int]
+      Sizes of existing core dimensions. Will be updated in-place.
+    arg : ndarray
+      Argument to examine.
+    core_dims : Tuple[str, ...]
+      Core dimensions for this argument.
   """
   if not core_dims:
     return
@@ -70,22 +160,19 @@ def _update_dim_sizes(dim_sizes, arg, core_dims):

 def _parse_input_dimensions(args, input_core_dims):
-  """
-  Parse broadcast and core dimensions for vectorize with a signature.
+  """Parse broadcast and core dimensions for vectorize with a signature.

-  Arguments
-  ---------
-  args : Tuple[ndarray, ...]
-    Tuple of input arguments to examine.
-  input_core_dims : List[Tuple[str, ...]]
-    List of core dimensions corresponding to each input.
+  Args:
+    args : Tuple[ndarray, ...]
+      Tuple of input arguments to examine.
+    input_core_dims : List[Tuple[str, ...]]
+      List of core dimensions corresponding to each input.

-  Returns
-  -------
-  broadcast_shape : Tuple[int, ...]
-    Common shape to broadcast all non-core dimensions to.
-  dim_sizes : Dict[str, int]
-    Common sizes for named core dimensions.
+  Returns:
+    broadcast_shape : Tuple[int, ...]
+      Common shape to broadcast all non-core dimensions to.
+    dim_sizes : Dict[str, int]
+      Common sizes for named core dimensions.
   """
   broadcast_args = []
   dim_sizes = {}
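The broadcast/core split described in this docstring can be sketched with plain NumPy (a hypothetical standalone version; names mirror the function above, but the real implementation also validates mismatched core sizes):

```python
import numpy as np

def parse_input_dimensions(args, input_core_dims):
  # Split each argument's shape into leading broadcast dims and trailing
  # named core dims, recording the size seen for each core-dim name.
  dim_sizes = {}
  broadcast_shapes = []
  for arg, core_dims in zip(args, input_core_dims):
    n_core = len(core_dims)
    broadcast_shapes.append(arg.shape[:arg.ndim - n_core])
    for name, size in zip(core_dims, arg.shape[arg.ndim - n_core:]):
      dim_sizes.setdefault(name, size)
  # NumPy >= 1.20 provides broadcast_shapes for combining the leading dims.
  return np.broadcast_shapes(*broadcast_shapes), dim_sizes

shape, sizes = parse_input_dimensions(
    (np.ones((5, 2, 3)), np.ones((3, 4))),
    [('n', 'p'), ('p', 'k')])
print(shape, sizes)  # (5,) {'n': 2, 'p': 3, 'k': 4}
```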
@@ -146,7 +233,21 @@ import functools

 def vectorize(signature):
-  """Vectorize a function using JAX."""
+  """Vectorize a function using JAX.
+
+  Turns an arbitrary function into a NumPy-style "gufunc". Once
+  you specify the behavior of the core axes, the rest will be
+  broadcast naturally.
+
+  Args:
+    signature: an einsum-style signature that defines how the core
+      dimensions are mapped between inputs and outputs.
+
+  Returns:
+    The vectorized 'gufunc' that will automatically broadcast
+    while maintaining the specified core logic. The returned
+    function also has a new ``axis`` parameter for specifying
+    which axis should be treated as the core one.
+  """
   input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

   def decorator(func):
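For comparison, signature-driven vectorization along these lines is also available as ``jax.numpy.vectorize`` (a usage sketch with a single-output signature; this module's decorator differs in also exposing an ``axis`` parameter on the returned function):

```python
import functools
import jax.numpy as jnp

@functools.partial(jnp.vectorize, signature='(n)->(n)')
def center(array):
  # Core logic on a 1D vector; the signature handles broadcasting.
  return array - jnp.mean(array)

x = jnp.arange(6.0).reshape(2, 3)
print(center(x).shape)  # (2, 3)
```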