[JAX] move example libraries from jax.experimental into jax.example_libraries

The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone example libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
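
For downstream code, the migration is a one-line change per import; below is a minimal sketch (the old `jax.experimental` paths keep working for a limited time through deprecation shims that emit a `FutureWarning`):

```python
# Deprecated import paths (still available for now, but emit a FutureWarning):
from jax.experimental import optimizers
from jax.experimental import stax

# New import paths:
from jax.example_libraries import optimizers
from jax.example_libraries import stax
```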

PiperOrigin-RevId: 404405186
Roy Frostig 2021-10-19 17:30:16 -07:00 committed by jax authors
parent 349d0d0879
commit 623c201054
25 changed files with 1068 additions and 979 deletions

View File

@ -14,6 +14,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).
* Breaking changes
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
## jax 0.2.24 (Oct 19, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).

View File

@ -98,11 +98,11 @@ For a deeper dive into JAX:
notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
You can also take a look at [the mini-libraries in
`jax.experimental`](https://github.com/google/jax/tree/main/jax/experimental/README.md),
`jax.example_libraries`](https://github.com/google/jax/tree/main/jax/example_libraries/README.md),
like [`stax` for building neural
networks](https://github.com/google/jax/tree/main/jax/experimental/README.md#neural-net-building-with-stax)
networks](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#neural-net-building-with-stax)
and [`optimizers` for first-order stochastic
optimization](https://github.com/google/jax/tree/main/jax/experimental/README.md#first-order-optimization),
optimization](https://github.com/google/jax/tree/main/jax/example_libraries/README.md#first-order-optimization),
or the [examples](https://github.com/google/jax/tree/main/examples).
## Transformations

View File

@ -0,0 +1,7 @@
jax.example_libraries.optimizers module
=======================================
.. automodule:: jax.example_libraries.optimizers
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,10 @@
jax.example_libraries package
=============================
.. toctree::
:maxdepth: 1
jax.example_libraries.optimizers
jax.example_libraries.stax
.. automodule:: jax.example_libraries

View File

@ -0,0 +1,7 @@
jax.example_libraries.stax module
=================================
.. automodule:: jax.example_libraries.stax
:members:
:undoc-members:
:show-inheritance:

View File

@ -1,7 +0,0 @@
jax.experimental.optimizers module
==================================
.. automodule:: jax.experimental.optimizers
:members:
:undoc-members:
:show-inheritance:

View File

@ -11,9 +11,7 @@ jax.experimental package
jax.experimental.loops
jax.experimental.maps
jax.experimental.pjit
jax.experimental.optimizers
jax.experimental.sparse
jax.experimental.stax
.. automodule:: jax.experimental

View File

@ -1,7 +0,0 @@
jax.experimental.stax module
============================
.. automodule:: jax.experimental.stax
:members:
:undoc-members:
:show-inheritance:

View File

@ -11,6 +11,7 @@ Subpackages
jax.numpy
jax.scipy
jax.example_libraries
jax.experimental
jax.image
jax.lax

View File

@ -24,7 +24,7 @@ import matplotlib.pyplot as plt
from jax import jit, grad, vmap
from jax import random
from jax.experimental import optimizers
from jax.example_libraries import optimizers
import jax.numpy as jnp
import jax.scipy.stats.norm as norm

View File

@ -75,8 +75,8 @@ from jax import grad
from jax import jit
from jax import random
from jax import vmap
from jax.experimental import optimizers
from jax.experimental import stax
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
from jax.examples import datasets

View File

@ -18,7 +18,7 @@ from functools import partial
import numpy.random as npr
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax import grad, jit, make_jaxpr, vmap, lax

View File

@ -14,8 +14,8 @@
"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
The mini-library jax.experimental.stax is for neural network building, and
the mini-library jax.experimental.optimizers is for first-order stochastic
The mini-library jax.example_libraries.stax is for neural network building, and
the mini-library jax.example_libraries.optimizers is for first-order stochastic
optimization.
"""
@ -27,9 +27,9 @@ import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from examples import datasets

View File

@ -27,9 +27,9 @@ import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit, grad, lax, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, FanOut, Relu, Softplus
from examples import datasets

View File

@ -22,11 +22,11 @@ import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
FanOut, Flatten, GeneralConv, Identity,
MaxPool, Relu, LogSoftmax)
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import (AvgPool, BatchNorm, Conv, Dense,
FanInSum, FanOut, Flatten, GeneralConv,
Identity, MaxPool, Relu, LogSoftmax)
# ResNet blocks compose other layers

View File

@ -25,13 +25,14 @@ constructor functions for common basic pairs, like `Conv` and `Relu`, and these
pairs can be composed in series using `stax.serial` or in parallel using
`stax.parallel`.
Here’s an example:
Here's an example:
```python
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
from jax.example_libraries import stax
from jax.example_libraries.stax import (
Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax)
# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
@ -54,20 +55,20 @@ predictions = net_apply(net_params, inputs)
### First-order optimization
JAX has a minimal optimization library focused on stochastic first-order
optimizers. Every optimizer is modeled as an `(init_fun, update_fun,
get_params)` triple of functions. The `init_fun` is used to initialize the
optimizer state, which could include things like momentum variables, and the
`update_fun` accepts a gradient and an optimizer state to produce a new
optimizer state. The `get_params` function extracts the current iterate (i.e.
the current parameters) from the optimizer state. The parameters being optimized
can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can
store your parameters however you’d like.
The file `optimizers.py` contains a minimal optimization library focused on
stochastic first-order optimizers. Every optimizer is modeled as an
`(init_fun, update_fun, get_params)` triple of functions. The `init_fun` is used
to initialize the optimizer state, which could include things like momentum
variables, and the `update_fun` accepts a gradient and an optimizer state to
produce a new optimizer state. The `get_params` function extracts the current
iterate (i.e. the current parameters) from the optimizer state. The parameters
being optimized can be ndarrays or arbitrarily-nested list/tuple/dict
structures, so you can store your parameters however you'd like.
Here’s an example, using `jit` to compile the whole update end-to-end:
Here's an example, using `jit` to compile the whole update end-to-end:
```python
from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax import jit, grad
# Define a simple squared-error loss

View File

@ -0,0 +1,13 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,613 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizers for use with JAX.
This module contains some convenient optimizer definitions, specifically
initialization and update functions, which can be used with ndarrays or
arbitrarily-nested tuple/list/dicts of ndarrays.
An optimizer is modeled as an ``(init_fun, update_fun, get_params)`` triple of
functions, where the component functions have these signatures:
::
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
::
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
::
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
Notice that an optimizer implementation has a lot of flexibility in the form of
opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to
the JAX transforms defined in api.py) and it has to be consumable by update_fun
and get_params.
Example Usage:
.. code-block:: python
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)
def step(step, opt_state):
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
for i in range(num_steps):
value, opt_state = step(i, opt_state)
"""
from typing import Any, Callable, NamedTuple, Tuple, Union
from collections import namedtuple
import functools
from functools import partial
import jax.numpy as jnp
from jax._src.util import safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
register_pytree_node)
map = safe_map
zip = safe_zip
# The implementation here basically works by flattening pytrees. There are two
# levels of pytrees to think about: the pytree of params, which we can think of
# as defining an "outer pytree", and a pytree produced by applying init_fun to
# each leaf of the params pytree, which we can think of as the "inner pytrees".
# Since pytrees can be flattened, that structure is isomorphic to a list of
# lists (with no further nesting).
OptimizerState = namedtuple("OptimizerState",
["packed_state", "tree_def", "subtree_defs"])
register_pytree_node(
OptimizerState,
lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)),
lambda data, xs: OptimizerState(xs[0], data[0], data[1])) # type: ignore[index]
Array = Any
Params = Any # Parameters are arbitrary nests of `jnp.ndarrays`.
State = Any # internal State
Updates = Params # Gradient updates are of the same type as parameters.
InitFn = Callable[[Params], OptimizerState]
Step = int
UpdateFn = Callable[[Step, Updates, OptimizerState], OptimizerState]
ParamsFn = Callable[[OptimizerState], Params]
class Optimizer(NamedTuple):
init_fn: InitFn
update_fn: UpdateFn
params_fn: ParamsFn
Schedule = Callable[[Step], float]
def optimizer(opt_maker: Callable[...,
Tuple[Callable[[Params], State],
Callable[[Step, Updates, Params], Params],
Callable[[State], Params]]]) -> Callable[..., Optimizer]:
"""Decorator to make an optimizer defined for arrays generalize to containers.
With this decorator, you can write init, update, and get_params functions that
each operate only on single arrays, and convert them to corresponding
functions that operate on pytrees of parameters. See the optimizers defined in
optimizers.py for examples.
Args:
opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
triple of functions that might only work with ndarrays, as per
.. code-block:: haskell
init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray
Returns:
An ``(init_fun, update_fun, get_params)`` triple of functions that work on
arbitrary pytrees, as per
.. code-block:: haskell
init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray
The OptimizerState pytree type used by the returned functions is isomorphic
to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state
instead as e.g. a partially-flattened data structure for performance.
"""
@functools.wraps(opt_maker)
def tree_opt_maker(*args, **kwargs):
init, update, get_params = opt_maker(*args, **kwargs)
@functools.wraps(init)
def tree_init(x0_tree):
x0_flat, tree = tree_flatten(x0_tree)
initial_states = [init(x0) for x0 in x0_flat]
states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
return OptimizerState(states_flat, tree, subtrees)
@functools.wraps(update)
def tree_update(i, grad_tree, opt_state):
states_flat, tree, subtrees = opt_state
grad_flat, tree2 = tree_flatten(grad_tree)
if tree2 != tree:
msg = ("optimizer update function was passed a gradient tree that did "
"not match the parameter tree structure with which it was "
"initialized: parameter tree {} and grad tree {}.")
raise TypeError(msg.format(tree, tree2))
states = map(tree_unflatten, subtrees, states_flat)
new_states = map(partial(update, i), grad_flat, states)
new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
for subtree, subtree2 in zip(subtrees, subtrees2):
if subtree2 != subtree:
msg = ("optimizer update function produced an output structure that "
"did not match its input structure: input {} and output {}.")
raise TypeError(msg.format(subtree, subtree2))
return OptimizerState(new_states_flat, tree, subtrees)
@functools.wraps(get_params)
def tree_get_params(opt_state):
states_flat, tree, subtrees = opt_state
states = map(tree_unflatten, subtrees, states_flat)
params = map(get_params, states)
return tree_unflatten(tree, params)
return Optimizer(tree_init, tree_update, tree_get_params)
return tree_opt_maker
### optimizers
@optimizer
def sgd(step_size):
"""Construct optimizer triple for stochastic gradient descent.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
return x0
def update(i, g, x):
return x - step_size(i) * g
def get_params(x):
return x
return Optimizer(init, update, get_params)
@optimizer
def momentum(step_size: Schedule, mass: float):
"""Construct optimizer triple for SGD with momentum.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
velocity = mass * velocity + g
x = x - step_size(i) * velocity
return x, velocity
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def nesterov(step_size: Schedule, mass: float):
"""Construct optimizer triple for SGD with Nesterov momentum.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
velocity = mass * velocity + g
x = x - step_size(i) * (mass * velocity + g)
return x, velocity
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def adagrad(step_size, momentum=0.9):
"""Construct optimizer triple for Adagrad.
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
g_sq = jnp.zeros_like(x0)
m = jnp.zeros_like(x0)
return x0, g_sq, m
def update(i, g, state):
x, g_sq, m = state
g_sq += jnp.square(g)
g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
x = x - step_size(i) * m
return x, g_sq, m
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
"""Construct optimizer triple for RMSProp.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = jnp.zeros_like(x0)
return x0, avg_sq_grad
def update(i, g, state):
x, avg_sq_grad = state
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
return x, avg_sq_grad
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
"""Construct optimizer triple for RMSProp with momentum.
This optimizer is separate from the rmsprop optimizer because it needs to
keep track of additional parameters.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
momentum: Momentum parameter.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = jnp.zeros_like(x0)
mom = jnp.zeros_like(x0)
return x0, avg_sq_grad, mom
def update(i, g, state):
x, avg_sq_grad, mom = state
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
x = x - mom
return x, avg_sq_grad, mom
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""Construct optimizer triple for Adam.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.999).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = jnp.zeros_like(x0)
v0 = jnp.zeros_like(x0)
return x0, m0, v0
def update(i, g, state):
x, m, v = state
m = (1 - b1) * g + b1 * m # First moment estimate.
v = (1 - b2) * jnp.square(g) + b2 * v # Second moment estimate.
mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
return x, m, v
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.999).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = jnp.zeros_like(x0)
u0 = jnp.zeros_like(x0)
return x0, m0, u0
def update(i, g, state):
x, m, u = state
m = (1 - b1) * g + b1 * m # First moment estimate.
u = jnp.maximum(b2 * u, jnp.abs(g)) # Update exponentially weighted infinity norm.
x = (x - (step_size(i) / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))) * m
/ (u + eps))
return x, m, u
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def sm3(step_size, momentum=0.9):
"""Construct optimizer triple for SM3.
Memory-Efficient Adaptive Optimization for Large-Scale Learning.
https://arxiv.org/abs/1901.11150
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def splice(seq, i, x):
lst = list(seq)
lst[i:i+1] = x
return lst
def broadcast_into(ndim, x, axis):
idx = splice([None] * ndim, axis, [slice(None)])
return x[tuple(idx)]
def init(x0):
x_shape = x0.shape
x0 = jnp.atleast_1d(x0)
vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
return x0, jnp.zeros_like(x0), vs, x_shape
def update(i, g, state):
x, m, vs, x_shape = state
vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
x = x - step_size(i) * m
vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
return x, m, vs, x_shape
def get_params(state):
x, _, _, x_shape = state
return x.reshape(x_shape)
return init, update, get_params
### learning rate schedules
def constant(step_size) -> Schedule:
def schedule(i):
return step_size
return schedule
def exponential_decay(step_size, decay_steps, decay_rate):
def schedule(i):
return step_size * decay_rate ** (i / decay_steps)
return schedule
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
if staircase:
def schedule(i):
return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
else:
def schedule(i):
return step_size / (1 + decay_rate * i / decay_steps)
return schedule
def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
def schedule(step_num):
step_num = jnp.minimum(step_num, decay_steps)
step_mult = (1 - step_num / decay_steps) ** power
return step_mult * (step_size - final_step_size) + final_step_size
return schedule
def piecewise_constant(boundaries: Any, values: Any):
boundaries = jnp.array(boundaries)
values = jnp.array(values)
if not boundaries.ndim == values.ndim == 1:
raise ValueError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise ValueError("boundaries length must be one shorter than values length")
def schedule(i):
return values[jnp.sum(i > boundaries)]
return schedule
def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
if callable(scalar_or_schedule):
return scalar_or_schedule
elif jnp.ndim(scalar_or_schedule) == 0:
return constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule))
### utilities
def l2_norm(tree):
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
leaves, _ = tree_flatten(tree)
return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
def clip_grads(grad_tree, max_norm):
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
norm = l2_norm(grad_tree)
normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
return tree_map(normalize, grad_tree)
### serialization utilities
class JoinPoint(object):
"""Marks the boundary between two joined (nested) pytrees."""
def __init__(self, subtree):
self.subtree = subtree
# Since pytrees are containers of numpy arrays, look iterable.
def __iter__(self):
yield self.subtree
def unpack_optimizer_state(opt_state):
"""Converts an OptimizerState to a marked pytree.
Converts an OptimizerState to a marked pytree with the leaves of the outer
pytree represented as JoinPoints to avoid losing information. This function is
intended to be useful when serializing optimizer states.
Args:
opt_state: An OptimizerState
Returns:
A pytree with JoinPoint leaves that contain a second level of pytrees.
"""
states_flat, tree_def, subtree_defs = opt_state
subtrees = map(tree_unflatten, subtree_defs, states_flat)
sentinels = [JoinPoint(subtree) for subtree in subtrees]
return tree_util.tree_unflatten(tree_def, sentinels)
def pack_optimizer_state(marked_pytree):
"""Converts a marked pytree to an OptimizerState.
The inverse of unpack_optimizer_state. Converts a marked pytree with the
leaves of the outer pytree represented as JoinPoints back into an
OptimizerState. This function is intended to be useful when deserializing
optimizer states.
Args:
marked_pytree: A pytree containing JoinPoint leaves that hold more pytrees.
Returns:
An equivalent OptimizerState to the input argument.
"""
sentinels, tree_def = tree_flatten(marked_pytree)
assert all(isinstance(s, JoinPoint) for s in sentinels)
subtrees = [s.subtree for s in sentinels]
states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
return OptimizerState(states_flat, tree_def, subtree_defs)
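
# The serialization helpers above convert between an OptimizerState and a
# plain marked pytree; a minimal sketch of the round trip (the adam optimizer
# and toy parameters below are illustrative choices):
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init({"w": jnp.zeros((3, 3)), "b": jnp.zeros(3)})

# unpack_optimizer_state yields a pytree with JoinPoint leaves that a
# checkpointing mechanism of your choice can traverse and store.
unpacked = optimizers.unpack_optimizer_state(opt_state)

# pack_optimizer_state is the inverse, rebuilding a usable OptimizerState.
restored = optimizers.pack_optimizer_state(unpacked)
assert (get_params(restored)["w"] == get_params(opt_state)["w"]).all()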

View File

@ -0,0 +1,351 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stax is a small but flexible neural net specification library from scratch.
For an example of its use, see examples/resnet50.py.
"""
import functools
import operator as op
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
# aliases for backwards compatibility
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax
# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
# init_fun: takes an rng key and an input shape and returns an
# (output_shape, params) pair,
# apply_fun: takes params, inputs, and an rng key and applies the layer.
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
k1, k2 = random.split(rng)
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return jnp.dot(inputs, W) + b
return init_fun, apply_fun
def GeneralConv(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=normal(1e-6)):
"""Layer construction function for a general convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
input_shape[lhs_spec.index('C')] if c == 'I' else
next(filter_shape_iter) for c in rhs_spec]
output_shape = lax.conv_general_shape_tuple(
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
dimension_numbers=dimension_numbers) + b
return init_fun, apply_fun
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=normal(1e-6)):
"""Layer construction function for a general transposed-convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
input_shape[lhs_spec.index('C')] if c == 'I' else
next(filter_shape_iter) for c in rhs_spec]
output_shape = lax.conv_transpose_shape_tuple(
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return lax.conv_transpose(inputs, W, strides, padding,
dimension_numbers=dimension_numbers) + b
return init_fun, apply_fun
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
ConvTranspose = functools.partial(GeneralConvTranspose,
('NHWC', 'HWIO', 'NHWC'))
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
beta_init=zeros, gamma_init=ones):
"""Layer construction function for a batch normalization layer."""
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
axis = (axis,) if jnp.isscalar(axis) else axis
def init_fun(rng, input_shape):
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
k1, k2 = random.split(rng)
beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
return input_shape, (beta, gamma)
def apply_fun(params, x, **kwargs):
beta, gamma = params
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
# (https://github.com/numpy/numpy/issues/12290)
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
z = normalize(x, axis, epsilon=epsilon)
if center and scale: return gamma[ed] * z + beta[ed]
if center: return z + beta[ed]
if scale: return gamma[ed] * z
return z
return init_fun, apply_fun
def elementwise(fun, **fun_kwargs):
"""Layer that applies a scalar function elementwise on its inputs."""
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
return init_fun, apply_fun
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)
def _pooling_layer(reducer, init_val, rescaler=None):
def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
"""Layer construction function for a pooling layer."""
strides = strides or (1,) * len(window_shape)
rescale = rescaler(window_shape, strides, padding) if rescaler else None
if spec is None:
non_spatial_axes = 0, len(window_shape) + 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
for i in sorted(non_spatial_axes):
window_shape = window_shape[:i] + (1,) + window_shape[i:]
strides = strides[:i] + (1,) + strides[i:]
def init_fun(rng, input_shape):
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
strides, padding)
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
input_shape, window_shape, strides, padding_vals, ones, ones)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
strides, padding)
return rescale(out, inputs, spec) if rescale else out
return init_fun, apply_fun
return PoolingLayer
MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)
def _normalize_by_window_size(dims, strides, padding):
def rescale(outputs, inputs, spec):
if spec is None:
non_spatial_axes = 0, inputs.ndim - 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
spatial_shape = tuple(inputs.shape[i]
for i in range(inputs.ndim)
if i not in non_spatial_axes)
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
for i in sorted(non_spatial_axes):
window_sizes = jnp.expand_dims(window_sizes, i)
return outputs / window_sizes
return rescale
AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
def Flatten():
"""Layer construction function for flattening all but the leading dim."""
def init_fun(rng, input_shape):
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
return output_shape, ()
def apply_fun(params, inputs, **kwargs):
return jnp.reshape(inputs, (inputs.shape[0], -1))
return init_fun, apply_fun
Flatten = Flatten()
def Identity():
"""Layer construction function for an identity layer."""
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: inputs
return init_fun, apply_fun
Identity = Identity()
def FanOut(num):
"""Layer construction function for a fan-out layer."""
init_fun = lambda rng, input_shape: ([input_shape] * num, ())
apply_fun = lambda params, inputs, **kwargs: [inputs] * num
return init_fun, apply_fun
def FanInSum():
"""Layer construction function for a fan-in sum layer."""
init_fun = lambda rng, input_shape: (input_shape[0], ())
apply_fun = lambda params, inputs, **kwargs: sum(inputs)
return init_fun, apply_fun
FanInSum = FanInSum()
def FanInConcat(axis=-1):
"""Layer construction function for a fan-in concatenation layer."""
def init_fun(rng, input_shape):
ax = axis % len(input_shape[0])
concat_size = sum(shape[ax] for shape in input_shape)
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
return jnp.concatenate(inputs, axis)
return init_fun, apply_fun
def Dropout(rate, mode='train'):
"""Layer construction function for a dropout layer with given rate."""
def init_fun(rng, input_shape):
return input_shape, ()
def apply_fun(params, inputs, **kwargs):
rng = kwargs.get('rng', None)
if rng is None:
msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
"argument. That is, instead of `apply_fun(params, inputs)`, call "
"it like `apply_fun(params, inputs, rng)` where `rng` is a "
"jax.random.PRNGKey value.")
raise ValueError(msg)
if mode == 'train':
keep = random.bernoulli(rng, rate, inputs.shape)
return jnp.where(keep, inputs / rate, 0)
else:
return inputs
return init_fun, apply_fun
# Composing layers via combinators
def serial(*layers):
"""Combinator for composing layers in serial.
Args:
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
composition of the given sequence of layers.
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(rng, input_shape):
params = []
for init_fun in init_funs:
rng, layer_rng = random.split(rng)
input_shape, param = init_fun(layer_rng, input_shape)
params.append(param)
return input_shape, params
def apply_fun(params, inputs, **kwargs):
rng = kwargs.pop('rng', None)
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
for fun, param, rng in zip(apply_funs, params, rngs):
inputs = fun(param, inputs, rng=rng, **kwargs)
return inputs
return init_fun, apply_fun
def parallel(*layers):
"""Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and
FanInSum layers.
Args:
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the
parallel composition of the given sequence of layers. In particular, the
returned layer takes a sequence of inputs and returns a sequence of outputs
with the same length as the argument `layers`.
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(rng, input_shape):
rngs = random.split(rng, nlayers)
return zip(*[init(rng, shape) for init, rng, shape
in zip(init_funs, rngs, input_shape)])
def apply_fun(params, inputs, **kwargs):
rng = kwargs.pop('rng', None)
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
return [f(p, x, rng=r, **kwargs) for f, p, x, r in zip(apply_funs, params, inputs, rngs)]
return init_fun, apply_fun
def shape_dependent(make_layer):
"""Combinator to delay layer constructor pair until input shapes are known.
Args:
make_layer: a one-argument function that takes an input shape as an argument
(a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the same
layer as returned by `make_layer` but with its construction delayed until
input shapes are known.
"""
def init_fun(rng, input_shape):
return make_layer(input_shape)[0](rng, input_shape)
def apply_fun(params, inputs, **kwargs):
return make_layer(inputs.shape)[1](params, inputs, **kwargs)
return init_fun, apply_fun
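
# A quick sketch of how these combinators compose (layer sizes and the PRNG
# seed are arbitrary): FanOut duplicates the input, parallel routes the copies
# through separate branches, and FanInSum adds the results, giving a small
# residual-style block.
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

block_init, block_apply = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.serial(stax.Dense(16), stax.Relu), stax.Identity),
    stax.FanInSum)

out_shape, block_params = block_init(random.PRNGKey(0), (-1, 16))
y = block_apply(block_params, jnp.ones((4, 16)))  # output shape (4, 16)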

View File

@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,602 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizers for use with JAX.
"""Optimizers have moved to jax.example_libraries.optimizers
This module contains some convenient optimizer definitions, specifically
initialization and update functions, which can be used with ndarrays or
arbitrarily-nested tuple/list/dicts of ndarrays.
An optimizer is modeled as an ``(init_fun, update_fun, get_params)`` triple of
functions, where the component functions have these signatures:
::
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
::
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
::
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
Notice that an optimizer implementation has a lot of flexibility in the form of
opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to
the JAX transforms defined in api.py) and it has to be consumable by update_fun
and get_params.
Example Usage:
.. code-block:: python
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)
def step(step, opt_state):
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
for i in range(num_steps):
value, opt_state = step(i, opt_state)
jax.experimental.optimizers is deprecated and will delegate to
jax.example_libraries.optimizers with a warning for backwards-compatibility
for a limited time.
"""
from typing import Any, Callable, NamedTuple, Tuple, Union
import warnings
from collections import namedtuple
import functools
from functools import partial
from jax.example_libraries.optimizers import * # noqa: F401,F403
import jax.numpy as jnp
from jax._src.util import safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
register_pytree_node)
_HAS_DYNAMIC_ATTRIBUTES = True
map = safe_map
zip = safe_zip
# The implementation here basically works by flattening pytrees. There are two
# levels of pytrees to think about: the pytree of params, which we can think of
# as defining an "outer pytree", and a pytree produced by applying init_fun to
# each leaf of the params pytree, which we can think of as the "inner pytrees".
# Since pytrees can be flattened, that structure is isomorphic to a list of
# lists (with no further nesting).
OptimizerState = namedtuple("OptimizerState",
["packed_state", "tree_def", "subtree_defs"])
register_pytree_node(
OptimizerState,
lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)),
lambda data, xs: OptimizerState(xs[0], data[0], data[1])) # type: ignore[index]
Array = Any
Params = Any # Parameters are arbitrary nests of `jnp.ndarrays`.
State = Any # internal State
Updates = Params # Gradient updates are of the same type as parameters.
InitFn = Callable[[Params], OptimizerState]
Step = int
UpdateFn = Callable[[Step, Updates, OptimizerState], OptimizerState]
ParamsFn = Callable[[OptimizerState], Params]
class Optimizer(NamedTuple):
init_fn: InitFn
update_fn: UpdateFn
params_fn: ParamsFn
Schedule = Callable[[Step], float]
def optimizer(opt_maker: Callable[...,
Tuple[Callable[[Params], State],
Callable[[Step, Updates, Params], Params],
Callable[[State], Params]]]) -> Callable[..., Optimizer]:
"""Decorator to make an optimizer defined for arrays generalize to containers.
With this decorator, you can write init, update, and get_params functions that
each operate only on single arrays, and convert them to corresponding
functions that operate on pytrees of parameters. See the optimizers defined in
optimizers.py for examples.
Args:
opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
triple of functions that might only work with ndarrays, as per
.. code-block:: haskell
init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray
Returns:
An ``(init_fun, update_fun, get_params)`` triple of functions that work on
arbitrary pytrees, as per
.. code-block:: haskell
init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray
The OptimizerState pytree type used by the returned functions is isomorphic
to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state
instead as e.g. a partially-flattened data structure for performance.
"""
@functools.wraps(opt_maker)
def tree_opt_maker(*args, **kwargs):
init, update, get_params = opt_maker(*args, **kwargs)
@functools.wraps(init)
def tree_init(x0_tree):
x0_flat, tree = tree_flatten(x0_tree)
initial_states = [init(x0) for x0 in x0_flat]
states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
return OptimizerState(states_flat, tree, subtrees)
@functools.wraps(update)
def tree_update(i, grad_tree, opt_state):
states_flat, tree, subtrees = opt_state
grad_flat, tree2 = tree_flatten(grad_tree)
if tree2 != tree:
msg = ("optimizer update function was passed a gradient tree that did "
"not match the parameter tree structure with which it was "
"initialized: parameter tree {} and grad tree {}.")
raise TypeError(msg.format(tree, tree2))
states = map(tree_unflatten, subtrees, states_flat)
new_states = map(partial(update, i), grad_flat, states)
new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
for subtree, subtree2 in zip(subtrees, subtrees2):
if subtree2 != subtree:
msg = ("optimizer update function produced an output structure that "
"did not match its input structure: input {} and output {}.")
raise TypeError(msg.format(subtree, subtree2))
return OptimizerState(new_states_flat, tree, subtrees)
@functools.wraps(get_params)
def tree_get_params(opt_state):
states_flat, tree, subtrees = opt_state
states = map(tree_unflatten, subtrees, states_flat)
params = map(get_params, states)
return tree_unflatten(tree, params)
return Optimizer(tree_init, tree_update, tree_get_params)
return tree_opt_maker
### optimizers
@optimizer
def sgd(step_size):
"""Construct optimizer triple for stochastic gradient descent.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
return x0
def update(i, g, x):
return x - step_size(i) * g
def get_params(x):
return x
return Optimizer(init, update, get_params)
@optimizer
def momentum(step_size: Schedule, mass: float):
"""Construct optimizer triple for SGD with momentum.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
velocity = mass * velocity + g
x = x - step_size(i) * velocity
return x, velocity
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def nesterov(step_size: Schedule, mass: float):
"""Construct optimizer triple for SGD with Nesterov momentum.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
v0 = jnp.zeros_like(x0)
return x0, v0
def update(i, g, state):
x, velocity = state
velocity = mass * velocity + g
x = x - step_size(i) * (mass * velocity + g)
return x, velocity
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def adagrad(step_size, momentum=0.9):
"""Construct optimizer triple for Adagrad.
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
g_sq = jnp.zeros_like(x0)
m = jnp.zeros_like(x0)
return x0, g_sq, m
def update(i, g, state):
x, g_sq, m = state
g_sq += jnp.square(g)
g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
x = x - step_size(i) * m
return x, g_sq, m
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
"""Construct optimizer triple for RMSProp.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = jnp.zeros_like(x0)
return x0, avg_sq_grad
def update(i, g, state):
x, avg_sq_grad = state
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
return x, avg_sq_grad
def get_params(state):
x, _ = state
return x
return init, update, get_params
@optimizer
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
"""Construct optimizer triple for RMSProp with momentum.
This optimizer is separate from the rmsprop optimizer because it needs to
keep track of additional parameters.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
momentum: Momentum parameter.
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
avg_sq_grad = jnp.zeros_like(x0)
mom = jnp.zeros_like(x0)
return x0, avg_sq_grad, mom
def update(i, g, state):
x, avg_sq_grad, mom = state
avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
x = x - mom
return x, avg_sq_grad, mom
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""Construct optimizer triple for Adam.
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.999).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = jnp.zeros_like(x0)
v0 = jnp.zeros_like(x0)
return x0, m0, v0
def update(i, g, state):
x, m, v = state
m = (1 - b1) * g + b1 * m # First moment estimate.
v = (1 - b2) * jnp.square(g) + b2 * v # Second moment estimate.
mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
return x, m, v
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
"""Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.999).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def init(x0):
m0 = jnp.zeros_like(x0)
u0 = jnp.zeros_like(x0)
return x0, m0, u0
def update(i, g, state):
x, m, u = state
m = (1 - b1) * g + b1 * m # First moment estimate.
u = jnp.maximum(b2 * u, jnp.abs(g)) # Update exponentially weighted infinity norm.
x = (x - (step_size(i) / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))) * m
/ (u + eps))
return x, m, u
def get_params(state):
x, _, _ = state
return x
return init, update, get_params
@optimizer
def sm3(step_size, momentum=0.9):
"""Construct optimizer triple for SM3.
Memory-Efficient Adaptive Optimization for Large-Scale Learning.
https://arxiv.org/abs/1901.11150
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size = make_schedule(step_size)
def splice(seq, i, x):
lst = list(seq)
lst[i:i+1] = x
return lst
def broadcast_into(ndim, x, axis):
idx = splice([None] * ndim, axis, [slice(None)])
return x[tuple(idx)]
def init(x0):
x_shape = x0.shape
x0 = jnp.atleast_1d(x0)
vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
return x0, jnp.zeros_like(x0), vs, x_shape
def update(i, g, state):
x, m, vs, x_shape = state
vs = [broadcast_into(g.ndim, v, i) for i, v in enumerate(vs)]
accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
x = x - step_size(i) * m
vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
return x, m, vs, x_shape
def get_params(state):
x, _, _, x_shape = state
return x.reshape(x_shape)
return init, update, get_params
### learning rate schedules
def constant(step_size) -> Schedule:
def schedule(i):
return step_size
return schedule
def exponential_decay(step_size, decay_steps, decay_rate):
def schedule(i):
return step_size * decay_rate ** (i / decay_steps)
return schedule
def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
if staircase:
def schedule(i):
return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
else:
def schedule(i):
return step_size / (1 + decay_rate * i / decay_steps)
return schedule
def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
def schedule(step_num):
step_num = jnp.minimum(step_num, decay_steps)
step_mult = (1 - step_num / decay_steps) ** power
return step_mult * (step_size - final_step_size) + final_step_size
return schedule
def piecewise_constant(boundaries: Any, values: Any):
boundaries = jnp.array(boundaries)
values = jnp.array(values)
if not boundaries.ndim == values.ndim == 1:
raise ValueError("boundaries and values must be sequences")
if not boundaries.shape[0] == values.shape[0] - 1:
raise ValueError("boundaries length must be one shorter than values length")
def schedule(i):
return values[jnp.sum(i > boundaries)]
return schedule
def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
if callable(scalar_or_schedule):
return scalar_or_schedule
elif jnp.ndim(scalar_or_schedule) == 0:
return constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule))
### utilities
def l2_norm(tree):
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
leaves, _ = tree_flatten(tree)
return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
def clip_grads(grad_tree, max_norm):
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
norm = l2_norm(grad_tree)
normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
return tree_map(normalize, grad_tree)
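# Editor's illustration (not part of this change): clip_grads rescales the whole
# gradient pytree when its global l2 norm exceeds max_norm, and leaves it
# unchanged otherwise. The gradient values below are made up.
def _clip_grads_usage_example():
  grads = {'w': jnp.full((10,), 3.0), 'b': jnp.zeros((10,))}
  clipped = clip_grads(grads, max_norm=1.0)
  return l2_norm(clipped)                           # ~1.0; the unclipped norm is ~9.49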
### serialization utilities
class JoinPoint(object):
"""Marks the boundary between two joined (nested) pytrees."""
def __init__(self, subtree):
self.subtree = subtree
  # Since pytrees are containers of numpy arrays, make JoinPoint look iterable too.
def __iter__(self):
yield self.subtree
def unpack_optimizer_state(opt_state):
"""Converts an OptimizerState to a marked pytree.
Converts an OptimizerState to a marked pytree with the leaves of the outer
pytree represented as JoinPoints to avoid losing information. This function is
intended to be useful when serializing optimizer states.
Args:
opt_state: An OptimizerState
Returns:
A pytree with JoinPoint leaves that contain a second level of pytrees.
"""
states_flat, tree_def, subtree_defs = opt_state
subtrees = map(tree_unflatten, subtree_defs, states_flat)
sentinels = [JoinPoint(subtree) for subtree in subtrees]
return tree_util.tree_unflatten(tree_def, sentinels)
def pack_optimizer_state(marked_pytree):
"""Converts a marked pytree to an OptimizerState.
The inverse of unpack_optimizer_state. Converts a marked pytree with the
leaves of the outer pytree represented as JoinPoints back into an
OptimizerState. This function is intended to be useful when deserializing
optimizer states.
Args:
marked_pytree: A pytree containing JoinPoint leaves that hold more pytrees.
Returns:
    An OptimizerState equivalent to the input argument.
"""
sentinels, tree_def = tree_flatten(marked_pytree)
assert all(isinstance(s, JoinPoint) for s in sentinels)
subtrees = [s.subtree for s in sentinels]
states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
return OptimizerState(states_flat, tree_def, subtree_defs)
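# Editor's illustration (not part of this change): a round trip through the two
# helpers above. The unpacked form is an ordinary pytree (with JoinPoint markers)
# that a checkpointing library of your choice could serialize;
# pack_optimizer_state then restores an equivalent OptimizerState.
def _optimizer_state_roundtrip_example():
  opt_init, opt_update, get_params = sm3(step_size=1e-1)
  opt_state = opt_init({'w': jnp.ones(3)})
  restored = pack_optimizer_state(unpack_optimizer_state(opt_state))
  return get_params(restored)                       # {'w': jnp.ones(3)}, unchanged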
warnings.warn('jax.experimental.optimizers is deprecated, '
'import jax.example_libraries.optimizers instead',
FutureWarning)

View File

@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,340 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stax is a small but flexible neural net specification library from scratch.
"""Stax has moved to jax.example_libraries.stax
For an example of its use, see examples/resnet50.py.
jax.experimental.stax is deprecated and will delegate to
jax.example_libraries.stax with a warning for backwards-compatibility
for a limited time.
"""
import warnings
import functools
import operator as op
from jax.example_libraries.stax import * # noqa: F401,F403
from jax import lax
from jax import random
import jax.numpy as jnp
_HAS_DYNAMIC_ATTRIBUTES = True
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
# aliases for backwards compatibility
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax
# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
# init_fun: takes an rng key and an input shape and returns an
# (output_shape, params) pair,
# apply_fun: takes params, inputs, and an rng key and applies the layer.
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
k1, k2 = random.split(rng)
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return jnp.dot(inputs, W) + b
return init_fun, apply_fun
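# Editor's illustration (not part of this change) of the (init_fun, apply_fun)
# convention described above, using Dense. The sizes and rng seed are made up.
def _dense_usage_example():
  init_fun, apply_fun = Dense(16)
  out_shape, params = init_fun(random.PRNGKey(0), (-1, 8))   # -1 marks the batch dim
  outputs = apply_fun(params, jnp.ones((32, 8)))
  return out_shape, outputs.shape                            # ((-1, 16), (32, 16))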
def GeneralConv(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=normal(1e-6)):
"""Layer construction function for a general convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
input_shape[lhs_spec.index('C')] if c == 'I' else
next(filter_shape_iter) for c in rhs_spec]
output_shape = lax.conv_general_shape_tuple(
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return lax.conv_general_dilated(inputs, W, strides, padding, one, one,
dimension_numbers=dimension_numbers) + b
return init_fun, apply_fun
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
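# Editor's illustration (not part of this change): Conv is GeneralConv specialized
# to NHWC inputs and HWIO kernels. The image and filter sizes below are made up.
def _conv_usage_example():
  init_fun, apply_fun = Conv(32, (3, 3), padding='SAME')
  out_shape, params = init_fun(random.PRNGKey(0), (1, 28, 28, 3))   # NHWC shape
  outputs = apply_fun(params, jnp.ones((1, 28, 28, 3)))
  return out_shape, outputs.shape                                   # both (1, 28, 28, 32)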
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=normal(1e-6)):
"""Layer construction function for a general transposed-convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
input_shape[lhs_spec.index('C')] if c == 'I' else
next(filter_shape_iter) for c in rhs_spec]
output_shape = lax.conv_transpose_shape_tuple(
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return lax.conv_transpose(inputs, W, strides, padding,
dimension_numbers=dimension_numbers) + b
return init_fun, apply_fun
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
ConvTranspose = functools.partial(GeneralConvTranspose,
('NHWC', 'HWIO', 'NHWC'))
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
beta_init=zeros, gamma_init=ones):
"""Layer construction function for a batch normalization layer."""
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
axis = (axis,) if jnp.isscalar(axis) else axis
def init_fun(rng, input_shape):
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
k1, k2 = random.split(rng)
beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
return input_shape, (beta, gamma)
def apply_fun(params, x, **kwargs):
beta, gamma = params
# TODO(phawkins): jnp.expand_dims should accept an axis tuple.
# (https://github.com/numpy/numpy/issues/12290)
ed = tuple(None if i in axis else slice(None) for i in range(jnp.ndim(x)))
z = normalize(x, axis, epsilon=epsilon)
if center and scale: return gamma[ed] * z + beta[ed]
if center: return z + beta[ed]
if scale: return gamma[ed] * z
return z
return init_fun, apply_fun
def elementwise(fun, **fun_kwargs):
"""Layer that applies a scalar function elementwise on its inputs."""
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
return init_fun, apply_fun
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)
def _pooling_layer(reducer, init_val, rescaler=None):
def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
"""Layer construction function for a pooling layer."""
strides = strides or (1,) * len(window_shape)
rescale = rescaler(window_shape, strides, padding) if rescaler else None
if spec is None:
non_spatial_axes = 0, len(window_shape) + 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
for i in sorted(non_spatial_axes):
window_shape = window_shape[:i] + (1,) + window_shape[i:]
strides = strides[:i] + (1,) + strides[i:]
def init_fun(rng, input_shape):
padding_vals = lax.padtype_to_pads(input_shape, window_shape,
strides, padding)
ones = (1,) * len(window_shape)
out_shape = lax.reduce_window_shape_tuple(
input_shape, window_shape, strides, padding_vals, ones, ones)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
out = lax.reduce_window(inputs, init_val, reducer, window_shape,
strides, padding)
return rescale(out, inputs, spec) if rescale else out
return init_fun, apply_fun
return PoolingLayer
MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)
def _normalize_by_window_size(dims, strides, padding):
def rescale(outputs, inputs, spec):
if spec is None:
non_spatial_axes = 0, inputs.ndim - 1
else:
non_spatial_axes = spec.index('N'), spec.index('C')
spatial_shape = tuple(inputs.shape[i]
for i in range(inputs.ndim)
if i not in non_spatial_axes)
one = jnp.ones(spatial_shape, dtype=inputs.dtype)
window_sizes = lax.reduce_window(one, 0., lax.add, dims, strides, padding)
for i in sorted(non_spatial_axes):
window_sizes = jnp.expand_dims(window_sizes, i)
return outputs / window_sizes
return rescale
AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
def Flatten():
"""Layer construction function for flattening all but the leading dim."""
def init_fun(rng, input_shape):
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
return output_shape, ()
def apply_fun(params, inputs, **kwargs):
return jnp.reshape(inputs, (inputs.shape[0], -1))
return init_fun, apply_fun
Flatten = Flatten()
def Identity():
"""Layer construction function for an identity layer."""
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: inputs
return init_fun, apply_fun
Identity = Identity()
def FanOut(num):
"""Layer construction function for a fan-out layer."""
init_fun = lambda rng, input_shape: ([input_shape] * num, ())
apply_fun = lambda params, inputs, **kwargs: [inputs] * num
return init_fun, apply_fun
def FanInSum():
"""Layer construction function for a fan-in sum layer."""
init_fun = lambda rng, input_shape: (input_shape[0], ())
apply_fun = lambda params, inputs, **kwargs: sum(inputs)
return init_fun, apply_fun
FanInSum = FanInSum()
def FanInConcat(axis=-1):
"""Layer construction function for a fan-in concatenation layer."""
def init_fun(rng, input_shape):
ax = axis % len(input_shape[0])
concat_size = sum(shape[ax] for shape in input_shape)
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
return jnp.concatenate(inputs, axis)
return init_fun, apply_fun
def Dropout(rate, mode='train'):
"""Layer construction function for a dropout layer with given rate."""
def init_fun(rng, input_shape):
return input_shape, ()
def apply_fun(params, inputs, **kwargs):
rng = kwargs.get('rng', None)
if rng is None:
msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
"argument. That is, instead of `apply_fun(params, inputs)`, call "
"it like `apply_fun(params, inputs, rng)` where `rng` is a "
"jax.random.PRNGKey value.")
raise ValueError(msg)
if mode == 'train':
keep = random.bernoulli(rng, rate, inputs.shape)
return jnp.where(keep, inputs / rate, 0)
else:
return inputs
return init_fun, apply_fun
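# Editor's illustration (not part of this change): in 'train' mode, apply_fun must
# be given a PRNG key via the `rng` keyword, as the error message above explains.
# Note that in this implementation `rate` is the keep probability passed to
# random.bernoulli, so rate=0.9 zeroes roughly 10% of the entries.
def _dropout_usage_example():
  init_fun, apply_fun = Dropout(rate=0.9, mode='train')
  _, params = init_fun(random.PRNGKey(0), (-1, 128))
  return apply_fun(params, jnp.ones((4, 128)), rng=random.PRNGKey(1))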
# Composing layers via combinators
def serial(*layers):
"""Combinator for composing layers in serial.
Args:
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
composition of the given sequence of layers.
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(rng, input_shape):
params = []
for init_fun in init_funs:
rng, layer_rng = random.split(rng)
input_shape, param = init_fun(layer_rng, input_shape)
params.append(param)
return input_shape, params
def apply_fun(params, inputs, **kwargs):
rng = kwargs.pop('rng', None)
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
for fun, param, rng in zip(apply_funs, params, rngs):
inputs = fun(param, inputs, rng=rng, **kwargs)
return inputs
return init_fun, apply_fun
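# Editor's illustration (not part of this change): a small MLP built with serial.
# The layer widths and input size are made up.
def _serial_usage_example():
  init_fun, apply_fun = serial(Dense(64), Relu, Dense(10), LogSoftmax)
  out_shape, params = init_fun(random.PRNGKey(0), (-1, 784))
  logits = apply_fun(params, jnp.ones((8, 784)))
  return out_shape, logits.shape                    # ((-1, 10), (8, 10))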
def parallel(*layers):
"""Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and
FanInSum layers.
Args:
*layers: a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the
parallel composition of the given sequence of layers. In particular, the
returned layer takes a sequence of inputs and returns a sequence of outputs
with the same length as the argument `layers`.
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(rng, input_shape):
rngs = random.split(rng, nlayers)
return zip(*[init(rng, shape) for init, rng, shape
in zip(init_funs, rngs, input_shape)])
def apply_fun(params, inputs, **kwargs):
rng = kwargs.pop('rng', None)
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
return [f(p, x, rng=r, **kwargs) for f, p, x, r in zip(apply_funs, params, inputs, rngs)]
return init_fun, apply_fun
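# Editor's illustration (not part of this change): a residual-style block combining
# FanOut, parallel, and FanInSum. The widths are made up; both branches must
# produce the same output shape for FanInSum to add them.
def _parallel_usage_example():
  init_fun, apply_fun = serial(
      FanOut(2),
      parallel(Dense(32), serial(Dense(32), Relu)),
      FanInSum)
  out_shape, params = init_fun(random.PRNGKey(0), (-1, 32))
  outputs = apply_fun(params, jnp.ones((4, 32)))
  return out_shape, outputs.shape                   # ((-1, 32), (4, 32))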
def shape_dependent(make_layer):
"""Combinator to delay layer constructor pair until input shapes are known.
Args:
make_layer: a one-argument function that takes an input shape as an argument
(a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the same
layer as returned by `make_layer` but with its construction delayed until
input shapes are known.
"""
def init_fun(rng, input_shape):
return make_layer(input_shape)[0](rng, input_shape)
def apply_fun(params, inputs, **kwargs):
return make_layer(inputs.shape)[1](params, inputs, **kwargs)
return init_fun, apply_fun
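# Editor's illustration (not part of this change): defer choosing a layer until the
# input shape is known, here sizing a Dense layer to match its input width.
def _shape_dependent_usage_example():
  init_fun, apply_fun = shape_dependent(lambda input_shape: Dense(input_shape[-1]))
  out_shape, params = init_fun(random.PRNGKey(0), (-1, 20))
  return out_shape                                  # (-1, 20)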
warnings.warn('jax.experimental.stax is deprecated, '
'import jax.example_libraries.stax instead',
FutureWarning)

View File

@ -13,5 +13,6 @@ filterwarnings =
ignore:numpy.ufunc size changed
ignore:.*experimental feature
ignore:index.*is deprecated.*:DeprecationWarning
ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"

View File

@ -25,7 +25,7 @@ import jax.numpy as jnp
import jax.scipy.special
from jax import random
from jax import jacfwd, jit
from jax.experimental import stax
from jax.example_libraries import stax
from jax.experimental.jet import jet, fact, zero_series
from jax import lax

View File

@ -24,7 +24,7 @@ import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import tree_util
from jax import lax
from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax.config import config
config.parse_flags_with_absl()

View File

@ -21,7 +21,7 @@ import numpy as np
from jax._src import test_util as jtu
from jax import random
from jax.experimental import stax
from jax.example_libraries import stax
from jax import dtypes
from jax.config import config