[JAX] move example libraries from jax.experimental into jax.example_libraries

The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone example libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186

This commit is contained in:
parent 349d0d0879
commit 623c201054
@@ -14,6 +14,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.24...main).

* Breaking changes
  * Moved `jax.experimental.stax` to `jax.example_libraries.stax`
  * Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`

## jax 0.2.24 (Oct 19, 2021)

* [GitHub
  commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).
@@ -98,11 +98,11 @@ For a deeper dive into JAX:
 notebooks](https://github.com/google/jax/tree/main/docs/notebooks).

 You can also take a look at [the mini-libraries in
-`jax.experimental`](https://github.com/google/jax/tree/main/jax/experimental/README.md),
+`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/experimental/README.md),
 like [`stax` for building neural
-networks](https://github.com/google/jax/tree/main/jax/experimental/README.md#neural-net-building-with-stax)
+networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
 and [`optimizers` for first-order stochastic
-optimization](https://github.com/google/jax/tree/main/jax/experimental/README.md#first-order-optimization),
+optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
 or the [examples](https://github.com/google/jax/tree/main/examples).

 ## Transformations
docs/jax.example_libraries.optimizers.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
jax.example_libraries.optimizers module
=======================================

.. automodule:: jax.example_libraries.optimizers
    :members:
    :undoc-members:
    :show-inheritance:
docs/jax.example_libraries.rst (new file, 10 lines)
@@ -0,0 +1,10 @@
jax.example_libraries package
=============================

.. toctree::
    :maxdepth: 1

    jax.example_libraries.optimizers
    jax.example_libraries.stax

.. automodule:: jax.example_libraries
docs/jax.example_libraries.stax.rst (new file, 7 lines)
@@ -0,0 +1,7 @@
jax.example_libraries.stax module
=================================

.. automodule:: jax.example_libraries.stax
    :members:
    :undoc-members:
    :show-inheritance:
docs/jax.experimental.optimizers.rst (deleted file)
@@ -1,7 +0,0 @@
jax.experimental.optimizers module
==================================

.. automodule:: jax.experimental.optimizers
    :members:
    :undoc-members:
    :show-inheritance:
@@ -11,9 +11,7 @@ jax.experimental package
    jax.experimental.loops
    jax.experimental.maps
    jax.experimental.pjit
-   jax.experimental.optimizers
    jax.experimental.sparse
-   jax.experimental.stax

.. automodule:: jax.experimental
docs/jax.experimental.stax.rst (deleted file)
@@ -1,7 +0,0 @@
jax.experimental.stax module
============================

.. automodule:: jax.experimental.stax
    :members:
    :undoc-members:
    :show-inheritance:
@@ -11,6 +11,7 @@ Subpackages

    jax.numpy
    jax.scipy
+   jax.example_libraries
    jax.experimental
    jax.image
    jax.lax
@@ -24,7 +24,7 @@ import matplotlib.pyplot as plt

 from jax import jit, grad, vmap
 from jax import random
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
 import jax.numpy as jnp
 import jax.scipy.stats.norm as norm
@@ -75,8 +75,8 @@ from jax import grad
 from jax import jit
 from jax import random
 from jax import vmap
-from jax.experimental import optimizers
-from jax.experimental import stax
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
 from jax.tree_util import tree_flatten, tree_unflatten
 import jax.numpy as jnp
 from jax.examples import datasets
@@ -18,7 +18,7 @@ from functools import partial
 import numpy.random as npr

 import jax.numpy as jnp
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
 from jax import grad, jit, make_jaxpr, vmap, lax
@@ -14,8 +14,8 @@

 """A basic MNIST example using JAX with the mini-libraries stax and optimizers.

-The mini-library jax.experimental.stax is for neural network building, and
-the mini-library jax.experimental.optimizers is for first-order stochastic
+The mini-library jax.example_libraries.stax is for neural network building, and
+the mini-library jax.example_libraries.optimizers is for first-order stochastic
 optimization.
 """
@@ -27,9 +27,9 @@ import numpy.random as npr

 import jax.numpy as jnp
 from jax import jit, grad, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import Dense, Relu, LogSoftmax
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import Dense, Relu, LogSoftmax
 from examples import datasets
@@ -27,9 +27,9 @@ import matplotlib.pyplot as plt
 import jax
 import jax.numpy as jnp
 from jax import jit, grad, lax, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import Dense, FanOut, Relu, Softplus
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import Dense, FanOut, Relu, Softplus
 from examples import datasets
@@ -22,11 +22,11 @@ import numpy.random as npr

 import jax.numpy as jnp
 from jax import jit, grad, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
-                                   FanOut, Flatten, GeneralConv, Identity,
-                                   MaxPool, Relu, LogSoftmax)
+from jax.example_libraries import optimizers
+from jax.example_libraries import stax
+from jax.example_libraries.stax import (AvgPool, BatchNorm, Conv, Dense,
+                                        FanInSum, FanOut, Flatten, GeneralConv,
+                                        Identity, MaxPool, Relu, LogSoftmax)


 # ResNet blocks compose other layers
@@ -25,13 +25,14 @@ constructor functions for common basic pairs, like `Conv` and `Relu`, and these
 pairs can be composed in series using `stax.serial` or in parallel using
 `stax.parallel`.

-Here’s an example:
+Here's an example:

 ```python
 import jax.numpy as jnp
 from jax import random
-from jax.experimental import stax
-from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
+from jax.example_libraries import stax
+from jax.example_libraries.stax import (
+    Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax)

 # Use stax to set up network initialization and evaluation functions
 net_init, net_apply = stax.serial(
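# --- A minimal sketch (not part of the diff above): the hunk is truncated by the
# diff view mid-call, so this is a self-contained version of the same kind of
# stax pipeline, assuming the post-move path `jax.example_libraries.stax`.
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (
    Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax)

# Compose layer (init_fun, apply_fun) pairs in series.
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(10), LogSoftmax)

# Initialize parameters for NHWC inputs of 28x28x1 images.
rng = random.PRNGKey(0)
out_shape, net_params = net_init(rng, (-1, 28, 28, 1))

# Apply the network to a batch of inputs.
inputs = random.normal(rng, (8, 28, 28, 1))
predictions = net_apply(net_params, inputs)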
@@ -54,20 +55,20 @@ predictions = net_apply(net_params, inputs)

 ### First-order optimization

-JAX has a minimal optimization library focused on stochastic first-order
-optimizers. Every optimizer is modeled as an `(init_fun, update_fun,
-get_params)` triple of functions. The `init_fun` is used to initialize the
-optimizer state, which could include things like momentum variables, and the
-`update_fun` accepts a gradient and an optimizer state to produce a new
-optimizer state. The `get_params` function extracts the current iterate (i.e.
-the current parameters) from the optimizer state. The parameters being optimized
-can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can
-store your parameters however you’d like.
+The file `optimizers.py` contains a minimal optimization library focused on
+stochastic first-order optimizers. Every optimizer is modeled as an
+`(init_fun, update_fun, get_params)` triple of functions. The `init_fun` is used
+to initialize the optimizer state, which could include things like momentum
+variables, and the `update_fun` accepts a gradient and an optimizer state to
+produce a new optimizer state. The `get_params` function extracts the current
+iterate (i.e. the current parameters) from the optimizer state. The parameters
+being optimized can be ndarrays or arbitrarily-nested list/tuple/dict
+structures, so you can store your parameters however you'd like.

-Here’s an example, using `jit` to compile the whole update end-to-end:
+Here's an example, using `jit` to compile the whole update end-to-end:

 ```python
-from jax.experimental import optimizers
+from jax.example_libraries import optimizers
 from jax import jit, grad

 # Define a simple squared-error loss
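# --- A minimal sketch (not part of the diff above): the example in this hunk is
# cut off by the diff view, so this completes a self-contained training loop
# showing the (init_fun, update_fun, get_params) contract described in the text.
import jax.numpy as jnp
from jax import jit, grad
from jax.example_libraries import optimizers

# Define a simple squared-error loss on a linear model.
def loss(params, batch):
  inputs, targets = batch
  predictions = jnp.dot(inputs, params)
  return jnp.sum((predictions - targets) ** 2)

# Every optimizer is an (init_fun, update_fun, get_params) triple.
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2, mass=0.9)

@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)
  grads = grad(loss)(params, batch)
  return opt_update(i, grads, opt_state)

params = jnp.zeros(3)
opt_state = opt_init(params)
batch = (jnp.ones((5, 3)), jnp.ones(5))
for i in range(10):
  opt_state = step(i, opt_state, batch)
trained_params = get_params(opt_state)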
jax/example_libraries/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
jax/example_libraries/optimizers.py (new file, 613 lines)
@@ -0,0 +1,613 @@
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Optimizers for use with JAX.
|
||||
|
||||
This module contains some convenient optimizer definitions, specifically
|
||||
initialization and update functions, which can be used with ndarrays or
|
||||
arbitrarily-nested tuple/list/dicts of ndarrays.
|
||||
|
||||
An optimizer is modeled as an ``(init_fun, update_fun, get_params)`` triple of
|
||||
functions, where the component functions have these signatures:
|
||||
|
||||
::
|
||||
|
||||
init_fun(params)
|
||||
|
||||
Args:
|
||||
params: pytree representing the initial parameters.
|
||||
|
||||
Returns:
|
||||
A pytree representing the initial optimizer state, which includes the
|
||||
initial parameters and may also include auxiliary values like initial
|
||||
momentum. The optimizer state pytree structure generally differs from that
|
||||
of `params`.
|
||||
|
||||
::
|
||||
|
||||
update_fun(step, grads, opt_state)
|
||||
|
||||
Args:
|
||||
step: integer representing the step index.
|
||||
grads: a pytree with the same structure as `get_params(opt_state)`
|
||||
representing the gradients to be used in updating the optimizer state.
|
||||
opt_state: a pytree representing the optimizer state to be updated.
|
||||
|
||||
Returns:
|
||||
A pytree with the same structure as the `opt_state` argument representing
|
||||
the updated optimizer state.
|
||||
|
||||
::
|
||||
|
||||
get_params(opt_state)
|
||||
|
||||
Args:
|
||||
opt_state: pytree representing an optimizer state.
|
||||
|
||||
Returns:
|
||||
A pytree representing the parameters extracted from `opt_state`, such that
|
||||
the invariant `params == get_params(init_fun(params))` holds true.
|
||||
|
||||
|
||||
Notice that an optimizer implementation has a lot of flexibility in the form of
|
||||
opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to
|
||||
the JAX transforms defined in api.py) and it has to be consumable by update_fun
|
||||
and get_params.
|
||||
|
||||
Example Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
|
||||
opt_state = opt_init(params)
|
||||
|
||||
def step(step, opt_state):
|
||||
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
|
||||
opt_state = opt_update(step, grads, opt_state)
|
||||
return value, opt_state
|
||||
|
||||
for i in range(num_steps):
|
||||
value, opt_state = step(i, opt_state)
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, NamedTuple, Tuple, Union
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax._src.util import safe_zip, safe_map, unzip2
|
||||
from jax import tree_util
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
register_pytree_node)
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
# The implementation here basically works by flattening pytrees. There are two
|
||||
# levels of pytrees to think about: the pytree of params, which we can think of
|
||||
# as defining an "outer pytree", and a pytree produced by applying init_fun to
|
||||
# each leaf of the params pytree, which we can think of as the "inner pytrees".
|
||||
# Since pytrees can be flattened, that structure is isomorphic to a list of
|
||||
# lists (with no further nesting).
|
||||
|
||||
OptimizerState = namedtuple("OptimizerState",
|
||||
["packed_state", "tree_def", "subtree_defs"])
|
||||
register_pytree_node(
|
||||
OptimizerState,
|
||||
lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)),
|
||||
lambda data, xs: OptimizerState(xs[0], data[0], data[1])) # type: ignore[index]
|
||||
|
||||
|
||||
Array = Any
|
||||
Params = Any # Parameters are arbitrary nests of `jnp.ndarrays`.
|
||||
State = Any # internal State
|
||||
Updates = Params # Gradient updates are of the same type as parameters.
|
||||
|
||||
InitFn = Callable[[Params], OptimizerState]
|
||||
Step = int
|
||||
UpdateFn = Callable[[Step, Updates, OptimizerState], OptimizerState]
|
||||
ParamsFn = Callable[[OptimizerState], Params]
|
||||
|
||||
class Optimizer(NamedTuple):
|
||||
init_fn: InitFn
|
||||
update_fn: UpdateFn
|
||||
params_fn: ParamsFn
|
||||
|
||||
Schedule = Callable[[Step], float]
|
||||
|
||||
def optimizer(opt_maker: Callable[...,
|
||||
Tuple[Callable[[Params], State],
|
||||
Callable[[Step, Updates, Params], Params],
|
||||
Callable[[State], Params]]]) -> Callable[..., Optimizer]:
|
||||
"""Decorator to make an optimizer defined for arrays generalize to containers.
|
||||
|
||||
With this decorator, you can write init, update, and get_params functions that
|
||||
each operate only on single arrays, and convert them to corresponding
|
||||
functions that operate on pytrees of parameters. See the optimizers defined in
|
||||
optimizers.py for examples.
|
||||
|
||||
Args:
|
||||
opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
|
||||
triple of functions that might only work with ndarrays, as per
|
||||
|
||||
.. code-block:: haskell
|
||||
|
||||
init_fun :: ndarray -> OptStatePytree ndarray
|
||||
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
|
||||
get_params :: OptStatePytree ndarray -> ndarray
|
||||
|
||||
Returns:
|
||||
An ``(init_fun, update_fun, get_params)`` triple of functions that work on
|
||||
arbitrary pytrees, as per
|
||||
|
||||
.. code-block:: haskell
|
||||
|
||||
init_fun :: ParameterPytree ndarray -> OptimizerState
|
||||
update_fun :: OptimizerState -> OptimizerState
|
||||
get_params :: OptimizerState -> ParameterPytree ndarray
|
||||
|
||||
The OptimizerState pytree type used by the returned functions is isomorphic
|
||||
to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state
|
||||
instead as e.g. a partially-flattened data structure for performance.
|
||||
"""
|
||||
|
||||
@functools.wraps(opt_maker)
|
||||
def tree_opt_maker(*args, **kwargs):
|
||||
init, update, get_params = opt_maker(*args, **kwargs)
|
||||
|
||||
@functools.wraps(init)
|
||||
def tree_init(x0_tree):
|
||||
x0_flat, tree = tree_flatten(x0_tree)
|
||||
initial_states = [init(x0) for x0 in x0_flat]
|
||||
states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
|
||||
return OptimizerState(states_flat, tree, subtrees)
|
||||
|
||||
@functools.wraps(update)
|
||||
def tree_update(i, grad_tree, opt_state):
|
||||
states_flat, tree, subtrees = opt_state
|
||||
grad_flat, tree2 = tree_flatten(grad_tree)
|
||||
if tree2 != tree:
|
||||
msg = ("optimizer update function was passed a gradient tree that did "
|
||||
"not match the parameter tree structure with which it was "
|
||||
"initialized: parameter tree {} and grad tree {}.")
|
||||
raise TypeError(msg.format(tree, tree2))
|
||||
states = map(tree_unflatten, subtrees, states_flat)
|
||||
new_states = map(partial(update, i), grad_flat, states)
|
||||
new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
|
||||
for subtree, subtree2 in zip(subtrees, subtrees2):
|
||||
if subtree2 != subtree:
|
||||
msg = ("optimizer update function produced an output structure that "
|
||||
"did not match its input structure: input {} and output {}.")
|
||||
raise TypeError(msg.format(subtree, subtree2))
|
||||
return OptimizerState(new_states_flat, tree, subtrees)
|
||||
|
||||
@functools.wraps(get_params)
|
||||
def tree_get_params(opt_state):
|
||||
states_flat, tree, subtrees = opt_state
|
||||
states = map(tree_unflatten, subtrees, states_flat)
|
||||
params = map(get_params, states)
|
||||
return tree_unflatten(tree, params)
|
||||
|
||||
return Optimizer(tree_init, tree_update, tree_get_params)
|
||||
return tree_opt_maker
|
||||
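To make the decorator's contract concrete, here is a minimal sketch (not part of the diff): a per-array update rule written only for single ndarrays and generalized to parameter pytrees by `@optimizer`. The `sign_sgd` name is made up for illustration.

```python
import jax.numpy as jnp
from jax.example_libraries.optimizers import optimizer, make_schedule

@optimizer
def sign_sgd(step_size):
  """Hypothetical sign-SGD rule, written only for single arrays."""
  step_size = make_schedule(step_size)
  def init(x0):
    return x0                              # per-array state is just the array
  def update(i, g, x):
    return x - step_size(i) * jnp.sign(g)
  def get_params(x):
    return x
  return init, update, get_params

# Thanks to @optimizer, the returned triple accepts arbitrary parameter pytrees.
opt_init, opt_update, get_params = sign_sgd(1e-3)
params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
opt_state = opt_init(params)
grads = {'w': jnp.ones((2, 2)), 'b': jnp.ones(2)}
opt_state = opt_update(0, grads, opt_state)
```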
|
||||
|
||||
### optimizers
|
||||
|
||||
@optimizer
|
||||
def sgd(step_size):
|
||||
"""Construct optimizer triple for stochastic gradient descent.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
return x0
|
||||
def update(i, g, x):
|
||||
return x - step_size(i) * g
|
||||
def get_params(x):
|
||||
return x
|
||||
return Optimizer(init, update, get_params)
|
||||
|
||||
@optimizer
|
||||
def momentum(step_size: Schedule, mass: float):
|
||||
"""Construct optimizer triple for SGD with momentum.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
mass: positive scalar representing the momentum coefficient.
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, v0
|
||||
def update(i, g, state):
|
||||
x, velocity = state
|
||||
velocity = mass * velocity + g
|
||||
x = x - step_size(i) * velocity
|
||||
return x, velocity
|
||||
def get_params(state):
|
||||
x, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def nesterov(step_size: Schedule, mass: float):
|
||||
"""Construct optimizer triple for SGD with Nesterov momentum.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
mass: positive scalar representing the momentum coefficient.
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, v0
|
||||
def update(i, g, state):
|
||||
x, velocity = state
|
||||
velocity = mass * velocity + g
|
||||
x = x - step_size(i) * (mass * velocity + g)
|
||||
return x, velocity
|
||||
def get_params(state):
|
||||
x, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def adagrad(step_size, momentum=0.9):
|
||||
"""Construct optimizer triple for Adagrad.
|
||||
|
||||
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
|
||||
http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
momentum: optional, a positive scalar value for momentum
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
|
||||
def init(x0):
|
||||
g_sq = jnp.zeros_like(x0)
|
||||
m = jnp.zeros_like(x0)
|
||||
return x0, g_sq, m
|
||||
|
||||
def update(i, g, state):
|
||||
x, g_sq, m = state
|
||||
g_sq += jnp.square(g)
|
||||
g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
|
||||
m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
|
||||
x = x - step_size(i) * m
|
||||
return x, g_sq, m
|
||||
|
||||
def get_params(state):
|
||||
x, _, _ = state
|
||||
return x
|
||||
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def rmsprop(step_size, gamma=0.9, eps=1e-8):
|
||||
"""Construct optimizer triple for RMSProp.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
gamma: Decay parameter.
|
||||
eps: Epsilon parameter.
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
avg_sq_grad = jnp.zeros_like(x0)
|
||||
return x0, avg_sq_grad
|
||||
def update(i, g, state):
|
||||
x, avg_sq_grad = state
|
||||
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
|
||||
x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
|
||||
return x, avg_sq_grad
|
||||
def get_params(state):
|
||||
x, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
|
||||
"""Construct optimizer triple for RMSProp with momentum.
|
||||
|
||||
This optimizer is separate from the rmsprop optimizer because it needs to
|
||||
keep track of additional parameters.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
gamma: Decay parameter.
|
||||
eps: Epsilon parameter.
|
||||
momentum: Momentum parameter.
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
avg_sq_grad = jnp.zeros_like(x0)
|
||||
mom = jnp.zeros_like(x0)
|
||||
return x0, avg_sq_grad, mom
|
||||
def update(i, g, state):
|
||||
x, avg_sq_grad, mom = state
|
||||
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
|
||||
mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
|
||||
x = x - mom
|
||||
return x, avg_sq_grad, mom
|
||||
def get_params(state):
|
||||
x, _, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
|
||||
"""Construct optimizer triple for Adam.
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
b1: optional, a positive scalar value for beta_1, the exponential decay rate
|
||||
for the first moment estimates (default 0.9).
|
||||
b2: optional, a positive scalar value for beta_2, the exponential decay rate
|
||||
for the second moment estimates (default 0.999).
|
||||
eps: optional, a positive scalar value for epsilon, a small constant for
|
||||
numerical stability (default 1e-8).
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
m0 = jnp.zeros_like(x0)
|
||||
v0 = jnp.zeros_like(x0)
|
||||
return x0, m0, v0
|
||||
def update(i, g, state):
|
||||
x, m, v = state
|
||||
m = (1 - b1) * g + b1 * m # First moment estimate.
|
||||
v = (1 - b2) * jnp.square(g) + b2 * v # Second moment estimate.
|
||||
mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
|
||||
vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
|
||||
x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
|
||||
return x, m, v
|
||||
def get_params(state):
|
||||
x, _, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
|
||||
"""Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
b1: optional, a positive scalar value for beta_1, the exponential decay rate
|
||||
for the first moment estimates (default 0.9).
|
||||
b2: optional, a positive scalar value for beta_2, the exponential decay rate
|
||||
for the second moment estimates (default 0.999).
|
||||
eps: optional, a positive scalar value for epsilon, a small constant for
|
||||
numerical stability (default 1e-8).
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
def init(x0):
|
||||
m0 = jnp.zeros_like(x0)
|
||||
u0 = jnp.zeros_like(x0)
|
||||
return x0, m0, u0
|
||||
def update(i, g, state):
|
||||
x, m, u = state
|
||||
m = (1 - b1) * g + b1 * m # First moment estimate.
|
||||
u = jnp.maximum(b2 * u, jnp.abs(g)) # Update exponentially weighted infinity norm.
|
||||
x = (x - (step_size(i) / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))) * m
|
||||
/ (u + eps))
|
||||
return x, m, u
|
||||
def get_params(state):
|
||||
x, _, _ = state
|
||||
return x
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
@optimizer
|
||||
def sm3(step_size, momentum=0.9):
|
||||
"""Construct optimizer triple for SM3.
|
||||
|
||||
Memory-Efficient Adaptive Optimization for Large-Scale Learning.
|
||||
https://arxiv.org/abs/1901.11150
|
||||
|
||||
Args:
|
||||
step_size: positive scalar, or a callable representing a step size schedule
|
||||
that maps the iteration index to a positive scalar.
|
||||
momentum: optional, a positive scalar value for momentum
|
||||
|
||||
Returns:
|
||||
An (init_fun, update_fun, get_params) triple.
|
||||
"""
|
||||
step_size = make_schedule(step_size)
|
||||
|
||||
def splice(seq, i, x):
|
||||
lst = list(seq)
|
||||
lst[i:i+1] = x
|
||||
return lst
|
||||
|
||||
def broadcast_into(ndim, x, axis):
|
||||
idx = splice([None] * ndim, axis, [slice(None)])
|
||||
return x[tuple(idx)]
|
||||
|
||||
def init(x0):
|
||||
x_shape = x0.shape
|
||||
x0 = jnp.atleast_1d(x0)
|
||||
vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
|
||||
return x0, jnp.zeros_like(x0), vs, x_shape
|
||||
|
||||
def update(i, g, state):
|
||||
x, m, vs, x_shape = state
|
||||
vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
|
||||
accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
|
||||
accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
|
||||
m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
|
||||
x = x - step_size(i) * m
|
||||
vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
|
||||
return x, m, vs, x_shape
|
||||
|
||||
def get_params(state):
|
||||
x, _, _, x_shape = state
|
||||
return x.reshape(x_shape)
|
||||
|
||||
return init, update, get_params
|
||||
|
||||
|
||||
### learning rate schedules
|
||||
|
||||
def constant(step_size) -> Schedule:
|
||||
def schedule(i):
|
||||
return step_size
|
||||
return schedule
|
||||
|
||||
def exponential_decay(step_size, decay_steps, decay_rate):
|
||||
def schedule(i):
|
||||
return step_size * decay_rate ** (i / decay_steps)
|
||||
return schedule
|
||||
|
||||
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
|
||||
if staircase:
|
||||
def schedule(i):
|
||||
return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
|
||||
else:
|
||||
def schedule(i):
|
||||
return step_size / (1 + decay_rate * i / decay_steps)
|
||||
return schedule
|
||||
|
||||
def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
|
||||
def schedule(step_num):
|
||||
step_num = jnp.minimum(step_num, decay_steps)
|
||||
step_mult = (1 - step_num / decay_steps) ** power
|
||||
return step_mult * (step_size - final_step_size) + final_step_size
|
||||
|
||||
return schedule
|
||||
|
||||
def piecewise_constant(boundaries: Any, values: Any):
|
||||
boundaries = jnp.array(boundaries)
|
||||
values = jnp.array(values)
|
||||
if not boundaries.ndim == values.ndim == 1:
|
||||
raise ValueError("boundaries and values must be sequences")
|
||||
if not boundaries.shape[0] == values.shape[0] - 1:
|
||||
raise ValueError("boundaries length must be one shorter than values length")
|
||||
|
||||
def schedule(i):
|
||||
return values[jnp.sum(i > boundaries)]
|
||||
return schedule
|
||||
|
||||
def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
|
||||
if callable(scalar_or_schedule):
|
||||
return scalar_or_schedule
|
||||
elif jnp.ndim(scalar_or_schedule) == 0:
|
||||
return constant(scalar_or_schedule)
|
||||
else:
|
||||
raise TypeError(type(scalar_or_schedule))
|
||||
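A short usage sketch (not part of the diff): any of these schedules can be passed wherever an optimizer takes a `step_size`, because `make_schedule` accepts either a scalar or a callable.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Halve the learning rate every 1000 steps and feed the schedule to adam.
schedule = optimizers.exponential_decay(step_size=1e-2, decay_steps=1000,
                                        decay_rate=0.5)
opt_init, opt_update, get_params = optimizers.adam(schedule)

opt_state = opt_init(jnp.zeros(4))
opt_state = opt_update(0, jnp.ones(4), opt_state)  # step 0 uses schedule(0) == 1e-2
```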
|
||||
|
||||
### utilities
|
||||
|
||||
def l2_norm(tree):
|
||||
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
|
||||
leaves, _ = tree_flatten(tree)
|
||||
return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
|
||||
|
||||
def clip_grads(grad_tree, max_norm):
|
||||
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
|
||||
norm = l2_norm(grad_tree)
|
||||
normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
|
||||
return tree_map(normalize, grad_tree)
|
||||
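A small sketch (not part of the diff) of these two utilities together: `l2_norm` measures the global norm across all leaves, and `clip_grads` rescales the whole gradient pytree when that norm exceeds `max_norm`.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

grads = {'w': jnp.full((3,), 10.0), 'b': jnp.array([4.0])}
print(optimizers.l2_norm(grads))                     # global L2 norm over all leaves
clipped = optimizers.clip_grads(grads, max_norm=1.0)
print(optimizers.l2_norm(clipped))                   # ~1.0 after rescaling
```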
|
||||
|
||||
### serialization utilities
|
||||
|
||||
class JoinPoint(object):
|
||||
"""Marks the boundary between two joined (nested) pytrees."""
|
||||
def __init__(self, subtree):
|
||||
self.subtree = subtree
|
||||
|
||||
# Since pytrees are containers of numpy arrays, look iterable.
|
||||
def __iter__(self):
|
||||
yield self.subtree
|
||||
|
||||
def unpack_optimizer_state(opt_state):
|
||||
"""Converts an OptimizerState to a marked pytree.
|
||||
|
||||
Converts an OptimizerState to a marked pytree with the leaves of the outer
|
||||
pytree represented as JoinPoints to avoid losing information. This function is
|
||||
intended to be useful when serializing optimizer states.
|
||||
|
||||
Args:
|
||||
opt_state: An OptimizerState
|
||||
Returns:
|
||||
A pytree with JoinPoint leaves that contain a second level of pytrees.
|
||||
"""
|
||||
states_flat, tree_def, subtree_defs = opt_state
|
||||
subtrees = map(tree_unflatten, subtree_defs, states_flat)
|
||||
sentinels = [JoinPoint(subtree) for subtree in subtrees]
|
||||
return tree_util.tree_unflatten(tree_def, sentinels)
|
||||
|
||||
def pack_optimizer_state(marked_pytree):
|
||||
"""Converts a marked pytree to an OptimizerState.
|
||||
|
||||
The inverse of unpack_optimizer_state. Converts a marked pytree with the
|
||||
leaves of the outer pytree represented as JoinPoints back into an
|
||||
OptimizerState. This function is intended to be useful when deserializing
|
||||
optimizer states.
|
||||
|
||||
Args:
|
||||
marked_pytree: A pytree containing JoinPoint leaves that hold more pytrees.
|
||||
Returns:
|
||||
An equivalent OptimizerState to the input argument.
|
||||
"""
|
||||
sentinels, tree_def = tree_flatten(marked_pytree)
|
||||
assert all(isinstance(s, JoinPoint) for s in sentinels)
|
||||
subtrees = [s.subtree for s in sentinels]
|
||||
states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
|
||||
return OptimizerState(states_flat, tree_def, subtree_defs)
|
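A round-trip sketch (not part of the diff) of the serialization helpers: `unpack_optimizer_state` turns the packed state into an ordinary pytree with `JoinPoint` leaves, which can then be saved with whatever mechanism you prefer, and `pack_optimizer_state` inverts it.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init({'w': jnp.ones(3)})

marked = optimizers.unpack_optimizer_state(opt_state)  # pytree with JoinPoint leaves
restored = optimizers.pack_optimizer_state(marked)     # inverse operation
assert get_params(restored)['w'].shape == (3,)
```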
jax/example_libraries/stax.py (new file, 351 lines)
@@ -0,0 +1,351 @@
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Stax is a small but flexible neural net specification library from scratch.
|
||||
|
||||
For an example of its use, see examples/resnet50.py.
|
||||
"""
|
||||
|
||||
|
||||
import functools
|
||||
import operator as op
|
||||
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
|
||||
leaky_relu, selu, gelu, normalize)
|
||||
from jax.nn.initializers import glorot_normal, normal, ones, zeros
|
||||
|
||||
# aliases for backwards compatibility
|
||||
glorot = glorot_normal
|
||||
randn = normal
|
||||
logsoftmax = log_softmax
|
||||
|
||||
# Following the convention used in Keras and tf.layers, we use CamelCase for the
|
||||
# names of layer constructors, like Conv and Relu, while using snake_case for
|
||||
# other functions, like lax.conv and relu.
|
||||
|
||||
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
|
||||
# init_fun: takes an rng key and an input shape and returns an
|
||||
# (output_shape, params) pair,
|
||||
# apply_fun: takes params, inputs, and an rng key and applies the layer.
|
||||
|
||||
|
||||
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
|
||||
"""Layer constructor function for a dense (fully-connected) layer."""
|
||||
def init_fun(rng, input_shape):
|
||||
output_shape = input_shape[:-1] + (out_dim,)
|
||||
k1, k2 = random.split(rng)
|
||||
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
|
||||
return output_shape, (W, b)
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
W, b = params
|
||||
return jnp.dot(inputs, W) + b
|
||||
return init_fun, apply_fun
|
||||
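A quick sketch (not part of the diff) exercising a single layer's `(init_fun, apply_fun)` pair directly, following the contract described in the comments above.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries.stax import Dense

init_fun, apply_fun = Dense(4)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 8))  # 8 input features
y = apply_fun(params, jnp.ones((2, 8)))                   # shape (2, 4)
```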
|
||||
|
||||
def GeneralConv(dimension_numbers, out_chan, filter_shape,
|
||||
strides=None, padding='VALID', W_init=None,
|
||||
b_init=normal(1e-6)):
|
||||
"""Layer construction function for a general convolution layer."""
|
||||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||||
one = (1,) * len(filter_shape)
|
||||
strides = strides or one
|
||||
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
|
||||
def init_fun(rng, input_shape):
|
||||
filter_shape_iter = iter(filter_shape)
|
||||
kernel_shape = [out_chan if c == 'O' else
|
||||
input_shape[lhs_spec.index('C')] if c == 'I' else
|
||||
next(filter_shape_iter) for c in rhs_spec]
|
||||
output_shape = lax.conv_general_shape_tuple(
|
||||
input_shape, kernel_shape, strides, padding, dimension_numbers)
|
||||
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
|
||||
k1, k2 = random.split(rng)
|
||||
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
|
||||
return output_shape, (W, b)
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
W, b = params
|
||||
return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
|
||||
dimension_numbers=dimension_numbers) + b
|
||||
return init_fun, apply_fun
|
||||
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
|
||||
|
||||
|
||||
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
|
||||
strides=None, padding='VALID', W_init=None,
|
||||
b_init=normal(1e-6)):
|
||||
"""Layer construction function for a general transposed-convolution layer."""
|
||||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||||
one = (1,) * len(filter_shape)
|
||||
strides = strides or one
|
||||
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
|
||||
def init_fun(rng, input_shape):
|
||||
filter_shape_iter = iter(filter_shape)
|
||||
kernel_shape = [out_chan if c == 'O' else
|
||||
input_shape[lhs_spec.index('C')] if c == 'I' else
|
||||
next(filter_shape_iter) for c in rhs_spec]
|
||||
output_shape = lax.conv_transpose_shape_tuple(
|
||||
input_shape, kernel_shape, strides, padding, dimension_numbers)
|
||||
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
|
||||
k1, k2 = random.split(rng)
|
||||
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
|
||||
return output_shape, (W, b)
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
W, b = params
|
||||
return lax.conv_transpose(inputs, W, strides, padding,
|
||||
dimension_numbers=dimension_numbers) + b
|
||||
return init_fun, apply_fun
|
||||
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
|
||||
ConvTranspose = functools.partial(GeneralConvTranspose,
|
||||
('NHWC', 'HWIO', 'NHWC'))
|
||||
|
||||
|
||||
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
|
||||
beta_init=zeros, gamma_init=ones):
|
||||
"""Layer construction function for a batch normalization layer."""
|
||||
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
|
||||
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
|
||||
axis = (axis,) if jnp.isscalar(axis) else axis
|
||||
def init_fun(rng, input_shape):
|
||||
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
|
||||
k1, k2 = random.split(rng)
|
||||
beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
|
||||
return input_shape, (beta, gamma)
|
||||
def apply_fun(params, x, **kwargs):
|
||||
beta, gamma = params
|
||||
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
|
||||
# (https://github.com/numpy/numpy/issues/12290)
|
||||
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
|
||||
z = normalize(x, axis, epsilon=epsilon)
|
||||
if center and scale: return gamma[ed] * z + beta[ed]
|
||||
if center: return z + beta[ed]
|
||||
if scale: return gamma[ed] * z
|
||||
return z
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
def elementwise(fun, **fun_kwargs):
|
||||
"""Layer that applies a scalar function elementwise on its inputs."""
|
||||
init_fun = lambda rng, input_shape: (input_shape, ())
|
||||
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
|
||||
return init_fun, apply_fun
|
||||
Tanh = elementwise(jnp.tanh)
|
||||
Relu = elementwise(relu)
|
||||
Exp = elementwise(jnp.exp)
|
||||
LogSoftmax = elementwise(log_softmax, axis=-1)
|
||||
Softmax = elementwise(softmax, axis=-1)
|
||||
Softplus = elementwise(softplus)
|
||||
Sigmoid = elementwise(sigmoid)
|
||||
Elu = elementwise(elu)
|
||||
LeakyRelu = elementwise(leaky_relu)
|
||||
Selu = elementwise(selu)
|
||||
Gelu = elementwise(gelu)
|
||||
|
||||
|
||||
def _pooling_layer(reducer, init_val, rescaler=None):
|
||||
def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
|
||||
"""Layer construction function for a pooling layer."""
|
||||
strides = strides or (1,) * len(window_shape)
|
||||
rescale = rescaler(window_shape, strides, padding) if rescaler else None
|
||||
|
||||
if spec is None:
|
||||
non_spatial_axes = 0, len(window_shape) + 1
|
||||
else:
|
||||
non_spatial_axes = spec.index('N'), spec.index('C')
|
||||
|
||||
for i in sorted(non_spatial_axes):
|
||||
window_shape = window_shape[:i] + (1,) + window_shape[i:]
|
||||
strides = strides[:i] + (1,) + strides[i:]
|
||||
|
||||
def init_fun(rng, input_shape):
|
||||
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
|
||||
strides, padding)
|
||||
ones = (1,) * len(window_shape)
|
||||
out_shape = lax.reduce_window_shape_tuple(
|
||||
input_shape, window_shape, strides, padding_vals, ones, ones)
|
||||
return out_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
|
||||
strides, padding)
|
||||
return rescale(out, inputs, spec) if rescale else out
|
||||
return init_fun, apply_fun
|
||||
return PoolingLayer
|
||||
MaxPool = _pooling_layer(lax.max, -jnp.inf)
|
||||
SumPool = _pooling_layer(lax.add, 0.)
|
||||
|
||||
|
||||
def _normalize_by_window_size(dims, strides, padding):
|
||||
def rescale(outputs, inputs, spec):
|
||||
if spec is None:
|
||||
non_spatial_axes = 0, inputs.ndim - 1
|
||||
else:
|
||||
non_spatial_axes = spec.index('N'), spec.index('C')
|
||||
|
||||
spatial_shape = tuple(inputs.shape[i]
|
||||
for i in range(inputs.ndim)
|
||||
if i not in non_spatial_axes)
|
||||
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
|
||||
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
|
||||
for i in sorted(non_spatial_axes):
|
||||
window_sizes = jnp.expand_dims(window_sizes, i)
|
||||
|
||||
return outputs / window_sizes
|
||||
return rescale
|
||||
AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
|
||||
|
||||
|
||||
def Flatten():
|
||||
"""Layer construction function for flattening all but the leading dim."""
|
||||
def init_fun(rng, input_shape):
|
||||
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
|
||||
return output_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
return jnp.reshape(inputs, (inputs.shape[0], -1))
|
||||
return init_fun, apply_fun
|
||||
Flatten = Flatten()
|
||||
|
||||
|
||||
def Identity():
|
||||
"""Layer construction function for an identity layer."""
|
||||
init_fun = lambda rng, input_shape: (input_shape, ())
|
||||
apply_fun = lambda params, inputs, **kwargs: inputs
|
||||
return init_fun, apply_fun
|
||||
Identity = Identity()
|
||||
|
||||
|
||||
def FanOut(num):
|
||||
"""Layer construction function for a fan-out layer."""
|
||||
init_fun = lambda rng, input_shape: ([input_shape] * num, ())
|
||||
apply_fun = lambda params, inputs, **kwargs: [inputs] * num
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
def FanInSum():
|
||||
"""Layer construction function for a fan-in sum layer."""
|
||||
init_fun = lambda rng, input_shape: (input_shape[0], ())
|
||||
apply_fun = lambda params, inputs, **kwargs: sum(inputs)
|
||||
return init_fun, apply_fun
|
||||
FanInSum = FanInSum()
|
||||
|
||||
|
||||
def FanInConcat(axis=-1):
|
||||
"""Layer construction function for a fan-in concatenation layer."""
|
||||
def init_fun(rng, input_shape):
|
||||
ax = axis % len(input_shape[0])
|
||||
concat_size = sum(shape[ax] for shape in input_shape)
|
||||
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
|
||||
return out_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
return jnp.concatenate(inputs, axis)
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
def Dropout(rate, mode='train'):
|
||||
"""Layer construction function for a dropout layer with given rate."""
|
||||
def init_fun(rng, input_shape):
|
||||
return input_shape, ()
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
rng = kwargs.get('rng', None)
|
||||
if rng is None:
|
||||
msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
|
||||
"argument. That is, instead of `apply_fun(params, inputs)`, call "
|
||||
"it like `apply_fun(params, inputs, rng)` where `rng` is a "
|
||||
"jax.random.PRNGKey value.")
|
||||
raise ValueError(msg)
|
||||
if mode == 'train':
|
||||
keep = random.bernoulli(rng, rate, inputs.shape)
|
||||
return jnp.where(keep, inputs / rate, 0)
|
||||
else:
|
||||
return inputs
|
||||
return init_fun, apply_fun
|
||||
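A usage sketch (not part of the diff): `Dropout`'s `apply_fun` must receive a PRNG key through the `rng` keyword, as the error message above explains; note that in this implementation `rate` acts as the keep probability.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries.stax import Dropout

init_fun, apply_fun = Dropout(0.9)                # keep probability 0.9, mode='train'
_, params = init_fun(random.PRNGKey(0), (-1, 5))
y = apply_fun(params, jnp.ones((4, 5)), rng=random.PRNGKey(1))
```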
|
||||
|
||||
# Composing layers via combinators
|
||||
|
||||
|
||||
def serial(*layers):
|
||||
"""Combinator for composing layers in serial.
|
||||
|
||||
Args:
|
||||
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
|
||||
|
||||
Returns:
|
||||
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
|
||||
composition of the given sequence of layers.
|
||||
"""
|
||||
nlayers = len(layers)
|
||||
init_funs, apply_funs = zip(*layers)
|
||||
def init_fun(rng, input_shape):
|
||||
params = []
|
||||
for init_fun in init_funs:
|
||||
rng, layer_rng = random.split(rng)
|
||||
input_shape, param = init_fun(layer_rng, input_shape)
|
||||
params.append(param)
|
||||
return input_shape, params
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
rng = kwargs.pop('rng', None)
|
||||
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
|
||||
for fun, param, rng in zip(apply_funs, params, rngs):
|
||||
inputs = fun(param, inputs, rng=rng, **kwargs)
|
||||
return inputs
|
||||
return init_fun, apply_fun
|
||||
|
||||
|
||||
def parallel(*layers):
|
||||
"""Combinator for composing layers in parallel.
|
||||
|
||||
The layer resulting from this combinator is often used with the FanOut and
|
||||
FanInSum layers.
|
||||
|
||||
Args:
|
||||
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
|
||||
|
||||
Returns:
|
||||
A new layer, meaning an (init_fun, apply_fun) pair, representing the
|
||||
parallel composition of the given sequence of layers. In particular, the
|
||||
returned layer takes a sequence of inputs and returns a sequence of outputs
|
||||
with the same length as the argument `layers`.
|
||||
"""
|
||||
nlayers = len(layers)
|
||||
init_funs, apply_funs = zip(*layers)
|
||||
def init_fun(rng, input_shape):
|
||||
rngs = random.split(rng, nlayers)
|
||||
return zip(*[init(rng, shape) for init, rng, shape
|
||||
in zip(init_funs, rngs, input_shape)])
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
rng = kwargs.pop('rng', None)
|
||||
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
|
||||
return [f(p, x, rng=r, **kwargs) for f, p, x, r in zip(apply_funs, params, inputs, rngs)]
|
||||
return init_fun, apply_fun
|
||||
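As the docstring notes, `parallel` is typically paired with `FanOut` and `FanInSum`; here is a small residual-block sketch (not part of the diff) in that style.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, FanInSum, FanOut, Identity, Relu

# y = f(x) + x, with f a two-layer MLP branch.
residual_block = stax.serial(
    FanOut(2),
    stax.parallel(stax.serial(Dense(16), Relu, Dense(16)), Identity),
    FanInSum)

init_fun, apply_fun = residual_block
out_shape, params = init_fun(random.PRNGKey(0), (-1, 16))
y = apply_fun(params, jnp.ones((3, 16)))  # shape (3, 16)
```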
|
||||
|
||||
def shape_dependent(make_layer):
|
||||
"""Combinator to delay layer constructor pair until input shapes are known.
|
||||
|
||||
Args:
|
||||
make_layer: a one-argument function that takes an input shape as an argument
|
||||
(a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
|
||||
|
||||
Returns:
|
||||
A new layer, meaning an (init_fun, apply_fun) pair, representing the same
|
||||
layer as returned by `make_layer` but with its construction delayed until
|
||||
input shapes are known.
|
||||
"""
|
||||
def init_fun(rng, input_shape):
|
||||
return make_layer(input_shape)[0](rng, input_shape)
|
||||
def apply_fun(params, inputs, **kwargs):
|
||||
return make_layer(inputs.shape)[1](params, inputs, **kwargs)
|
||||
return init_fun, apply_fun
|
@@ -1,4 +1,4 @@
-# Copyright 2018 Google LLC
+# Copyright 2021 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,602 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Optimizers for use with JAX.
|
||||
"""Optimizers have moved to jax.example_libraries.optimizers
|
||||
|
||||
This module contains some convenient optimizer definitions, specifically
|
||||
initialization and update functions, which can be used with ndarrays or
|
||||
arbitrarily-nested tuple/list/dicts of ndarrays.
|
||||
|
||||
An optimizer is modeled as an ``(init_fun, update_fun, get_params)`` triple of
|
||||
functions, where the component functions have these signatures:
|
||||
|
||||
::
|
||||
|
||||
init_fun(params)
|
||||
|
||||
Args:
|
||||
params: pytree representing the initial parameters.
|
||||
|
||||
Returns:
|
||||
A pytree representing the initial optimizer state, which includes the
|
||||
initial parameters and may also include auxiliary values like initial
|
||||
momentum. The optimizer state pytree structure generally differs from that
|
||||
of `params`.
|
||||
|
||||
::
|
||||
|
||||
update_fun(step, grads, opt_state)
|
||||
|
||||
Args:
|
||||
step: integer representing the step index.
|
||||
grads: a pytree with the same structure as `get_params(opt_state)`
|
||||
representing the gradients to be used in updating the optimizer state.
|
||||
opt_state: a pytree representing the optimizer state to be updated.
|
||||
|
||||
Returns:
|
||||
A pytree with the same structure as the `opt_state` argument representing
|
||||
the updated optimizer state.
|
||||
|
||||
::
|
||||
|
||||
get_params(opt_state)
|
||||
|
||||
Args:
|
||||
opt_state: pytree representing an optimizer state.
|
||||
|
||||
Returns:
|
||||
A pytree representing the parameters extracted from `opt_state`, such that
|
||||
the invariant `params == get_params(init_fun(params))` holds true.
|
||||
|
||||
|
||||
Notice that an optimizer implementation has a lot of flexibility in the form of
|
||||
opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to
|
||||
the JAX transforms defined in api.py) and it has to be consumable by update_fun
|
||||
and get_params.
|
||||
|
||||
Example Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
|
||||
opt_state = opt_init(params)
|
||||
|
||||
def step(step, opt_state):
|
||||
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
|
||||
opt_state = opt_update(step, grads, opt_state)
|
||||
return value, opt_state
|
||||
|
||||
for i in range(num_steps):
|
||||
value, opt_state = step(i, opt_state)
|
||||
jax.experimental.optimizers is deprecated and will delegate to
|
||||
jax.example_libraries.optimizers with a warning for backwards-compatibility
|
||||
for a limited time.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, NamedTuple, Tuple, Union
|
||||
import warnings
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
from functools import partial
|
||||
from jax.example_libraries.optimizers import * # noqa: F401,F403
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax._src.util import safe_zip, safe_map, unzip2
|
||||
from jax import tree_util
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
register_pytree_node)
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
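A migration sketch (not part of the diff): during the deprecation window the old path re-exports the new module's public names via the `import *` above, so existing imports keep working while new code moves to `jax.example_libraries`.

```python
# Importing the old path may emit a deprecation warning during the transition.
from jax.experimental import optimizers as old_optimizers       # deprecated shim
from jax.example_libraries import optimizers as new_optimizers  # preferred path

assert old_optimizers.adam is new_optimizers.adam  # same objects, re-exported
```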
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
# The implementation here basically works by flattening pytrees. There are two
|
||||
# levels of pytrees to think about: the pytree of params, which we can think of
|
||||
# as defining an "outer pytree", and a pytree produced by applying init_fun to
|
||||
# each leaf of the params pytree, which we can think of as the "inner pytrees".
|
||||
# Since pytrees can be flattened, that structure is isomorphic to a list of
|
||||
# lists (with no further nesting).
|
||||
|
||||
OptimizerState = namedtuple("OptimizerState",
|
||||
["packed_state", "tree_def", "subtree_defs"])
|
||||
register_pytree_node(
|
||||
OptimizerState,
|
||||
lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)),
|
||||
lambda data, xs: OptimizerState(xs[0], data[0], data[1])) # type: ignore[index]
|
||||
|
||||
|
||||
Array = Any
|
||||
Params = Any  # Parameters are arbitrary nests of `jnp.ndarrays`.
State = Any   # internal State
Updates = Params  # Gradient updates are of the same type as parameters.

InitFn = Callable[[Params], OptimizerState]
Step = int
UpdateFn = Callable[[Step, Updates, OptimizerState], OptimizerState]
ParamsFn = Callable[[OptimizerState], Params]

class Optimizer(NamedTuple):
  init_fn: InitFn
  update_fn: UpdateFn
  params_fn: ParamsFn

Schedule = Callable[[Step], float]

def optimizer(opt_maker: Callable[...,
  Tuple[Callable[[Params], State],
        Callable[[Step, Updates, Params], Params],
        Callable[[State], Params]]]) -> Callable[..., Optimizer]:
  """Decorator to make an optimizer defined for arrays generalize to containers.

  With this decorator, you can write init, update, and get_params functions that
  each operate only on single arrays, and convert them to corresponding
  functions that operate on pytrees of parameters. See the optimizers defined in
  optimizers.py for examples.

  Args:
    opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
      triple of functions that might only work with ndarrays, as per

      .. code-block:: haskell

          init_fun :: ndarray -> OptStatePytree ndarray
          update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
          get_params :: OptStatePytree ndarray -> ndarray

  Returns:
    An ``(init_fun, update_fun, get_params)`` triple of functions that work on
    arbitrary pytrees, as per

      .. code-block:: haskell

          init_fun :: ParameterPytree ndarray -> OptimizerState
          update_fun :: OptimizerState -> OptimizerState
          get_params :: OptimizerState -> ParameterPytree ndarray

  The OptimizerState pytree type used by the returned functions is isomorphic
  to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state
  instead as e.g. a partially-flattened data structure for performance.
  """

  @functools.wraps(opt_maker)
  def tree_opt_maker(*args, **kwargs):
    init, update, get_params = opt_maker(*args, **kwargs)

    @functools.wraps(init)
    def tree_init(x0_tree):
      x0_flat, tree = tree_flatten(x0_tree)
      initial_states = [init(x0) for x0 in x0_flat]
      states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
      return OptimizerState(states_flat, tree, subtrees)

    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
        raise TypeError(msg.format(tree, tree2))
      states = map(tree_unflatten, subtrees, states_flat)
      new_states = map(partial(update, i), grad_flat, states)
      new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
      for subtree, subtree2 in zip(subtrees, subtrees2):
        if subtree2 != subtree:
          msg = ("optimizer update function produced an output structure that "
                 "did not match its input structure: input {} and output {}.")
          raise TypeError(msg.format(subtree, subtree2))
      return OptimizerState(new_states_flat, tree, subtrees)

    @functools.wraps(get_params)
    def tree_get_params(opt_state):
      states_flat, tree, subtrees = opt_state
      states = map(tree_unflatten, subtrees, states_flat)
      params = map(get_params, states)
      return tree_unflatten(tree, params)

    return Optimizer(tree_init, tree_update, tree_get_params)
  return tree_opt_maker
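
# A minimal usage sketch for the decorator above (``sign_sgd`` is a made-up
# optimizer for illustration, not part of this library): the per-array
# init/update/get_params functions are lifted to arbitrary parameter pytrees.
import jax.numpy as jnp
from jax.example_libraries import optimizers

@optimizers.optimizer
def sign_sgd(step_size):
  """Sign-gradient descent, written for a single array."""
  def init(x0):
    return x0
  def update(i, g, x):
    return x - step_size * jnp.sign(g)
  def get_params(x):
    return x
  return init, update, get_params

opt_init, opt_update, get_params = sign_sgd(1e-3)
opt_state = opt_init({'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)})  # a pytree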

### optimizers

@optimizer
def sgd(step_size):
  """Construct optimizer triple for stochastic gradient descent.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    return x0
  def update(i, g, x):
    return x - step_size(i) * g
  def get_params(x):
    return x
  return Optimizer(init, update, get_params)
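
# A minimal training-loop sketch (the toy data and loss are assumptions, not
# part of the change) showing how the (opt_init, opt_update, get_params)
# triple returned by ``sgd`` is used together with ``jax.grad``.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss(params, batch):
  inputs, targets = batch
  preds = jnp.dot(inputs, params['w']) + params['b']
  return jnp.mean((preds - targets) ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-2)
opt_state = opt_init({'w': jnp.zeros((3, 1)), 'b': jnp.zeros(1)})

batch = (jnp.ones((8, 3)), jnp.ones((8, 1)))
for i in range(100):
  grads = jax.grad(loss)(get_params(opt_state), batch)
  opt_state = opt_update(i, grads, opt_state)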

@optimizer
def momentum(step_size: Schedule, mass: float):
  """Construct optimizer triple for SGD with momentum.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    v0 = jnp.zeros_like(x0)
    return x0, v0
  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity + g
    x = x - step_size(i) * velocity
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def nesterov(step_size: Schedule, mass: float):
  """Construct optimizer triple for SGD with Nesterov momentum.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    v0 = jnp.zeros_like(x0)
    return x0, v0
  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity + g
    x = x - step_size(i) * (mass * velocity + g)
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def adagrad(step_size, momentum=0.9):
  """Construct optimizer triple for Adagrad.

  Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
  http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)

  def init(x0):
    g_sq = jnp.zeros_like(x0)
    m = jnp.zeros_like(x0)
    return x0, g_sq, m

  def update(i, g, state):
    x, g_sq, m = state
    g_sq += jnp.square(g)
    g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
    m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
    x = x - step_size(i) * m
    return x, g_sq, m

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params

@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
  """Construct optimizer triple for RMSProp.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    avg_sq_grad = jnp.zeros_like(x0)
    return x0, avg_sq_grad
  def update(i, g, state):
    x, avg_sq_grad = state
    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
    x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
    return x, avg_sq_grad
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
  """Construct optimizer triple for RMSProp with momentum.

  This optimizer is separate from the rmsprop optimizer because it needs to
  keep track of additional parameters.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.
    momentum: Momentum parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    avg_sq_grad = jnp.zeros_like(x0)
    mom = jnp.zeros_like(x0)
    return x0, avg_sq_grad, mom
  def update(i, g, state):
    x, avg_sq_grad, mom = state
    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
    mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
    x = x - mom
    return x, avg_sq_grad, mom
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    v0 = jnp.zeros_like(x0)
    return x0, m0, v0
  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
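
# A small sketch (the toy quadratic loss is an assumption) of wrapping the
# Adam step in ``jax.jit``; optimizer states are pytrees of arrays, so they
# pass through transformations like ``jit`` unchanged.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3, b1=0.9, b2=0.999)

def loss(params):
  return jnp.sum(params ** 2)

@jax.jit
def step(i, opt_state):
  g = jax.grad(loss)(get_params(opt_state))
  return opt_update(i, g, opt_state)

opt_state = opt_init(jnp.arange(3.0))
opt_state = step(0, opt_state)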

@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    u0 = jnp.zeros_like(x0)
    return x0, m0, u0
  def update(i, g, state):
    x, m, u = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    u = jnp.maximum(b2 * u, jnp.abs(g))  # Update exponentially weighted infinity norm.
    x = (x - (step_size(i) / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))) * m
         / (u + eps))
    return x, m, u
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def sm3(step_size, momentum=0.9):
  """Construct optimizer triple for SM3.

  Memory-Efficient Adaptive Optimization for Large-Scale Learning.
  https://arxiv.org/abs/1901.11150

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)

  def splice(seq, i, x):
    lst = list(seq)
    lst[i:i+1] = x
    return lst

  def broadcast_into(ndim, x, axis):
    idx = splice([None] * ndim, axis, [slice(None)])
    return x[tuple(idx)]

  def init(x0):
    x_shape = x0.shape
    x0 = jnp.atleast_1d(x0)
    vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
    return x0, jnp.zeros_like(x0), vs, x_shape

  def update(i, g, state):
    x, m, vs, x_shape = state
    vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
    accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
    accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
    m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
    x = x - step_size(i) * m
    vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
    return x, m, vs, x_shape

  def get_params(state):
    x, _, _, x_shape = state
    return x.reshape(x_shape)

  return init, update, get_params


### learning rate schedules

def constant(step_size) -> Schedule:
  def schedule(i):
    return step_size
  return schedule

def exponential_decay(step_size, decay_steps, decay_rate):
  def schedule(i):
    return step_size * decay_rate ** (i / decay_steps)
  return schedule

def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
  if staircase:
    def schedule(i):
      return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
  else:
    def schedule(i):
      return step_size / (1 + decay_rate * i / decay_steps)
  return schedule

def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
  def schedule(step_num):
    step_num = jnp.minimum(step_num, decay_steps)
    step_mult = (1 - step_num / decay_steps) ** power
    return step_mult * (step_size - final_step_size) + final_step_size

  return schedule

def piecewise_constant(boundaries: Any, values: Any):
  boundaries = jnp.array(boundaries)
  values = jnp.array(values)
  if not boundaries.ndim == values.ndim == 1:
    raise ValueError("boundaries and values must be sequences")
  if not boundaries.shape[0] == values.shape[0] - 1:
    raise ValueError("boundaries length must be one shorter than values length")

  def schedule(i):
    return values[jnp.sum(i > boundaries)]
  return schedule

def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
  if callable(scalar_or_schedule):
    return scalar_or_schedule
  elif jnp.ndim(scalar_or_schedule) == 0:
    return constant(scalar_or_schedule)
  else:
    raise TypeError(type(scalar_or_schedule))
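
# Sketch: any schedule above can be passed wherever a ``step_size`` is
# expected, since the optimizers call ``make_schedule`` on it; a bare scalar
# is wrapped by ``constant``. (Values are chosen arbitrarily for illustration.)
from jax.example_libraries import optimizers

schedule = optimizers.exponential_decay(step_size=1e-2, decay_steps=1000,
                                        decay_rate=0.5)
opt_init, opt_update, get_params = optimizers.momentum(schedule, mass=0.9)
print(schedule(0), schedule(1000))  # 0.01, then 0.005 one decay period later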

### utilities

def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves, _ = tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))

def clip_grads(grad_tree, max_norm):
  """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
  return tree_map(normalize, grad_tree)
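
# Sketch of global-norm gradient clipping with the two helpers above
# (the gradient values are arbitrary illustrative numbers).
import jax.numpy as jnp
from jax.example_libraries import optimizers

grads = {'w': jnp.full((2, 2), 3.0), 'b': jnp.full((2,), 4.0)}
print(optimizers.l2_norm(grads))              # global l2 norm over all leaves
clipped = optimizers.clip_grads(grads, max_norm=1.0)
print(optimizers.l2_norm(clipped))            # rescaled down to max_norm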

### serialization utilities

class JoinPoint(object):
  """Marks the boundary between two joined (nested) pytrees."""
  def __init__(self, subtree):
    self.subtree = subtree

  # Since pytrees are containers of numpy arrays, look iterable.
  def __iter__(self):
    yield self.subtree

def unpack_optimizer_state(opt_state):
  """Converts an OptimizerState to a marked pytree.

  Converts an OptimizerState to a marked pytree with the leaves of the outer
  pytree represented as JoinPoints to avoid losing information. This function is
  intended to be useful when serializing optimizer states.

  Args:
    opt_state: An OptimizerState
  Returns:
    A pytree with JoinPoint leaves that contain a second level of pytrees.
  """
  states_flat, tree_def, subtree_defs = opt_state
  subtrees = map(tree_unflatten, subtree_defs, states_flat)
  sentinels = [JoinPoint(subtree) for subtree in subtrees]
  return tree_util.tree_unflatten(tree_def, sentinels)

def pack_optimizer_state(marked_pytree):
  """Converts a marked pytree to an OptimizerState.

  The inverse of unpack_optimizer_state. Converts a marked pytree with the
  leaves of the outer pytree represented as JoinPoints back into an
  OptimizerState. This function is intended to be useful when deserializing
  optimizer states.

  Args:
    marked_pytree: A pytree containing JoinPoint leaves that hold more pytrees.
  Returns:
    An equivalent OptimizerState to the input argument.
  """
  sentinels, tree_def = tree_flatten(marked_pytree)
  assert all(isinstance(s, JoinPoint) for s in sentinels)
  subtrees = [s.subtree for s in sentinels]
  states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
  return OptimizerState(states_flat, tree_def, subtree_defs)
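
# Sketch of the intended round trip: ``unpack_optimizer_state`` turns an
# OptimizerState into an ordinary pytree (with JoinPoint leaves) that a
# serialization library of your choice could handle, and
# ``pack_optimizer_state`` restores it. (Optimizer choice and shapes here are
# arbitrary assumptions.)
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init({'w': jnp.zeros((2, 2))})

packed = optimizers.unpack_optimizer_state(opt_state)  # plain, marked pytree
restored = optimizers.pack_optimizer_state(packed)     # back to OptimizerState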

warnings.warn('jax.experimental.optimizers is deprecated, '
              'import jax.example_libraries.optimizers instead',
              FutureWarning)

@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,340 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stax is a small but flexible neural net specification library from scratch.
"""Stax has moved to jax.example_libraries.stax

For an example of its use, see examples/resnet50.py.
jax.experimental.stax is deprecated and will delegate to
jax.example_libraries.stax with a warning for backwards-compatibility
for a limited time.
"""

import warnings

import functools
import operator as op
from jax.example_libraries.stax import *  # noqa: F401,F403

from jax import lax
from jax import random
import jax.numpy as jnp
_HAS_DYNAMIC_ATTRIBUTES = True

from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
                    leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros

# aliases for backwards compatibility
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax

# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.

# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.


def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b
  return init_fun, apply_fun
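
# Sketch of the (init_fun, apply_fun) protocol described in the comment block
# above, using ``Dense`` (shapes are arbitrary; -1 stands for the batch dim).
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.Dense(32)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 64))  # out_shape == (-1, 32)
y = apply_fun(params, jnp.ones((8, 64)))                   # y.shape == (8, 32)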

def GeneralConv(dimension_numbers, out_chan, filter_shape,
                strides=None, padding='VALID', W_init=None,
                b_init=normal(1e-6)):
  """Layer construction function for a general convolution layer."""
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  strides = strides or one
  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
  def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    k1, k2 = random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
                                    dimension_numbers=dimension_numbers) + b
  return init_fun, apply_fun
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))


def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', W_init=None,
                         b_init=normal(1e-6)):
  """Layer construction function for a general transposed-convolution layer."""
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  strides = strides or one
  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
  def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_transpose_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    k1, k2 = random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return lax.conv_transpose(inputs, W, strides, padding,
                              dimension_numbers=dimension_numbers) + b
  return init_fun, apply_fun
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
ConvTranspose = functools.partial(GeneralConvTranspose,
                                  ('NHWC', 'HWIO', 'NHWC'))


def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
  """Layer construction function for a batch normalization layer."""
  _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
  _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
  axis = (axis,) if jnp.isscalar(axis) else axis
  def init_fun(rng, input_shape):
    shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    k1, k2 = random.split(rng)
    beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
    return input_shape, (beta, gamma)
  def apply_fun(params, x, **kwargs):
    beta, gamma = params
    # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
    z = normalize(x, axis, epsilon=epsilon)
    if center and scale: return gamma[ed] * z + beta[ed]
    if center: return z + beta[ed]
    if scale: return gamma[ed] * z
    return z
  return init_fun, apply_fun


def elementwise(fun, **fun_kwargs):
  """Layer that applies a scalar function elementwise on its inputs."""
  init_fun = lambda rng, input_shape: (input_shape, ())
  apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
  return init_fun, apply_fun
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)
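
# Sketch: ``elementwise`` is also the easiest way to add a custom activation
# layer; the slope value here is an arbitrary illustration.
from jax import nn
from jax.example_libraries import stax

MyLeakyRelu = stax.elementwise(nn.leaky_relu, negative_slope=0.1)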

def _pooling_layer(reducer, init_val, rescaler=None):
  def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
    """Layer construction function for a pooling layer."""
    strides = strides or (1,) * len(window_shape)
    rescale = rescaler(window_shape, strides, padding) if rescaler else None

    if spec is None:
      non_spatial_axes = 0, len(window_shape) + 1
    else:
      non_spatial_axes = spec.index('N'), spec.index('C')

    for i in sorted(non_spatial_axes):
      window_shape = window_shape[:i] + (1,) + window_shape[i:]
      strides = strides[:i] + (1,) + strides[i:]

    def init_fun(rng, input_shape):
      padding_vals = lax.padtype_to_pads(input_shape, window_shape,
                                         strides, padding)
      ones = (1,) * len(window_shape)
      out_shape = lax.reduce_window_shape_tuple(
        input_shape, window_shape, strides, padding_vals, ones, ones)
      return out_shape, ()
    def apply_fun(params, inputs, **kwargs):
      out = lax.reduce_window(inputs, init_val, reducer, window_shape,
                              strides, padding)
      return rescale(out, inputs, spec) if rescale else out
    return init_fun, apply_fun
  return PoolingLayer
MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)


def _normalize_by_window_size(dims, strides, padding):
  def rescale(outputs, inputs, spec):
    if spec is None:
      non_spatial_axes = 0, inputs.ndim - 1
    else:
      non_spatial_axes = spec.index('N'), spec.index('C')

    spatial_shape = tuple(inputs.shape[i]
                          for i in range(inputs.ndim)
                          if i not in non_spatial_axes)
    one = jnp.ones(spatial_shape, dtype=inputs.dtype)
    window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
    for i in sorted(non_spatial_axes):
      window_sizes = jnp.expand_dims(window_sizes, i)

    return outputs / window_sizes
  return rescale
AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)


def Flatten():
  """Layer construction function for flattening all but the leading dim."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
    return output_shape, ()
  def apply_fun(params, inputs, **kwargs):
    return jnp.reshape(inputs, (inputs.shape[0], -1))
  return init_fun, apply_fun
Flatten = Flatten()


def Identity():
  """Layer construction function for an identity layer."""
  init_fun = lambda rng, input_shape: (input_shape, ())
  apply_fun = lambda params, inputs, **kwargs: inputs
  return init_fun, apply_fun
Identity = Identity()


def FanOut(num):
  """Layer construction function for a fan-out layer."""
  init_fun = lambda rng, input_shape: ([input_shape] * num, ())
  apply_fun = lambda params, inputs, **kwargs: [inputs] * num
  return init_fun, apply_fun


def FanInSum():
  """Layer construction function for a fan-in sum layer."""
  init_fun = lambda rng, input_shape: (input_shape[0], ())
  apply_fun = lambda params, inputs, **kwargs: sum(inputs)
  return init_fun, apply_fun
FanInSum = FanInSum()


def FanInConcat(axis=-1):
  """Layer construction function for a fan-in concatenation layer."""
  def init_fun(rng, input_shape):
    ax = axis % len(input_shape[0])
    concat_size = sum(shape[ax] for shape in input_shape)
    out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
    return out_shape, ()
  def apply_fun(params, inputs, **kwargs):
    return jnp.concatenate(inputs, axis)
  return init_fun, apply_fun


def Dropout(rate, mode='train'):
  """Layer construction function for a dropout layer with given rate."""
  def init_fun(rng, input_shape):
    return input_shape, ()
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.get('rng', None)
    if rng is None:
      msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng)` where `rng` is a "
             "jax.random.PRNGKey value.")
      raise ValueError(msg)
    if mode == 'train':
      keep = random.bernoulli(rng, rate, inputs.shape)
      return jnp.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun
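
# Sketch: Dropout is stochastic, so ``apply_fun`` needs a PRNG key passed as
# the ``rng`` keyword (note that in the implementation above, ``rate`` is used
# as the keep probability). Shapes here are arbitrary.
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.Dropout(rate=0.9, mode='train')
_, params = init_fun(random.PRNGKey(0), (-1, 128))
y = apply_fun(params, jnp.ones((4, 128)), rng=random.PRNGKey(1))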

# Composing layers via combinators


def serial(*layers):
  """Combinator for composing layers in serial.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
    composition of the given sequence of layers.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)
  def init_fun(rng, input_shape):
    params = []
    for init_fun in init_funs:
      rng, layer_rng = random.split(rng)
      input_shape, param = init_fun(layer_rng, input_shape)
      params.append(param)
    return input_shape, params
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    for fun, param, rng in zip(apply_funs, params, rngs):
      inputs = fun(param, inputs, rng=rng, **kwargs)
    return inputs
  return init_fun, apply_fun
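
# Sketch: a small MLP assembled with ``serial``; the layer sizes are arbitrary.
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

net_init, net_apply = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

out_shape, net_params = net_init(random.PRNGKey(0), (-1, 784))
logits = net_apply(net_params, jnp.ones((32, 784)))   # shape (32, 10)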

def parallel(*layers):
  """Combinator for composing layers in parallel.

  The layer resulting from this combinator is often used with the FanOut and
  FanInSum layers.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the
    parallel composition of the given sequence of layers. In particular, the
    returned layer takes a sequence of inputs and returns a sequence of outputs
    with the same length as the argument `layers`.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)
  def init_fun(rng, input_shape):
    rngs = random.split(rng, nlayers)
    return zip(*[init(rng, shape) for init, rng, shape
                 in zip(init_funs, rngs, input_shape)])
  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    return [f(p, x, rng=r, **kwargs) for f, p, x, r in zip(apply_funs, params, inputs, rngs)]
  return init_fun, apply_fun
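
# Sketch: as the docstring suggests, ``parallel`` is typically paired with
# FanOut/FanInSum, e.g. to express a residual (skip) connection. The widths
# are arbitrary but must match so the two branches can be summed.
from jax import random
from jax.example_libraries import stax

residual = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.serial(stax.Dense(64), stax.Relu, stax.Dense(64)),
                  stax.Identity),
    stax.FanInSum)

init_fun, apply_fun = residual
_, params = init_fun(random.PRNGKey(0), (-1, 64))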

def shape_dependent(make_layer):
  """Combinator to delay layer constructor pair until input shapes are known.

  Args:
    make_layer: a one-argument function that takes an input shape as an argument
      (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the same
    layer as returned by `make_layer` but with its construction delayed until
    input shapes are known.
  """
  def init_fun(rng, input_shape):
    return make_layer(input_shape)[0](rng, input_shape)
  def apply_fun(params, inputs, **kwargs):
    return make_layer(inputs.shape)[1](params, inputs, **kwargs)
  return init_fun, apply_fun
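
# Sketch: ``shape_dependent`` defers layer construction until the input shape
# is known, e.g. to derive a width from it (the doubling rule is made up).
from jax import random
from jax.example_libraries import stax

layer = stax.shape_dependent(lambda input_shape: stax.Dense(2 * input_shape[-1]))
init_fun, apply_fun = layer
out_shape, params = init_fun(random.PRNGKey(0), (-1, 16))   # out_shape == (-1, 32)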

warnings.warn('jax.experimental.stax is deprecated, '
              'import jax.example_libraries.stax instead',
              FutureWarning)

@ -13,5 +13,6 @@ filterwarnings =
    ignore:numpy.ufunc size changed
    ignore:.*experimental feature
    ignore:index.*is deprecated.*:DeprecationWarning
    ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"

@ -25,7 +25,7 @@ import jax.numpy as jnp
import jax.scipy.special
from jax import random
from jax import jacfwd, jit
from jax.experimental import stax
from jax.example_libraries import stax
from jax.experimental.jet import jet, fact, zero_series
from jax import lax

@ -24,7 +24,7 @@ import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import tree_util
from jax import lax
from jax.experimental import optimizers
from jax.example_libraries import optimizers

from jax.config import config
config.parse_flags_with_absl()

@ -21,7 +21,7 @@ import numpy as np

from jax._src import test_util as jtu
from jax import random
from jax.experimental import stax
from jax.example_libraries import stax
from jax import dtypes

from jax.config import config