
The goal of this change is to support shape polymorphism for operations such as average, which needs to divide by the size of a dimension, or indexing, which needs to normalize indices by comparing them with 0 and adding the dimension size to negative indices. In both cases the size of a dimension must be used as a value in the array computation. Normally, dimension sizes are used only as parameters that customize primitives.

This change introduces `core.dim_as_value`, which must be applied to a dimension size before that size is used as a value in the array computation. For example:

```python
def average(x):
  return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

`core.dim_as_value` is the identity function when the dimension size is constant. Otherwise it uses a new primitive, `shape_poly.dim_as_value_p`.

This does not fundamentally change the flavor of shape polymorphism supported in jax2tf. Intermediate shapes and their values may depend on the input shapes, but a shape never depends on the input values. In fact, `dim_as_value` could already have been expressed as:

```python
def dim_as_value(d):
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`, `lax.dynamic_slice` and `lax.dynamic_update_slice` by using `core.dim_as_value` internally. To fully roll out the solution, however, we need to make `core.dim_as_value` a public API and teach users how to use it when they want shape polymorphism. Alternatively, there may be a way to convert dimension polynomials to values automatically when they are passed to lax primitives.
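The snippet below is a minimal sketch of the constant-shape case only, and it needs nothing beyond `jax.numpy`. The names `dim_as_value_via_broadcast` and `normalize_index` are hypothetical helpers for illustration and are not part of JAX. The first uses the broadcast-and-sum trick to turn a dimension size into an array value. The second performs the negative-index normalization described above. Because the shapes here are concrete, `x.shape[0]` is a plain Python int, so this does not exercise jax2tf's symbolic dimensions.

```python
import jax.numpy as jnp

def dim_as_value_via_broadcast(d):
  # Hypothetical helper: recovers the dimension size `d` as an array value
  # by summing a length-`d` vector of ones (the broadcast-and-sum trick).
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))

def average(x):
  # Divides by the size of axis 0, used as a value in the computation.
  return jnp.sum(x, axis=0) / dim_as_value_via_broadcast(x.shape[0])

def normalize_index(i, d):
  # Hypothetical helper: maps a possibly negative index into [0, d)
  # by comparing it with 0 and adding the dimension size.
  return jnp.where(i < 0, i + dim_as_value_via_broadcast(d), i)

x = jnp.arange(12.0).reshape(4, 3)
print(average(x))                                # → [4.5 5.5 6.5]
print(normalize_index(jnp.array(-1), x.shape[0]))  # → 3
```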
Mini-libraries
JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is only <300 source lines of code, so take a look inside and adapt them as you need!
👉 Note: each mini-library is meant to be an inspiration, but not a prescription.
To serve that purpose, it is best to keep their code samples minimal, so we generally will not merge PRs adding new features. Instead, please send your lovely pull requests and design ideas to more fully featured libraries like Haiku, Flax, or Trax.
Neural-net building with Stax
Stax is a functional neural network building library. The basic idea is that a single layer or an entire network can be modeled as an `(init_fun, apply_fun)` pair. The `init_fun` is used to initialize network parameters, and the `apply_fun` takes parameters and inputs to produce outputs. There are constructor functions for common basic pairs, like `Conv` and `Relu`, and these pairs can be composed in series using `stax.serial` or in parallel using `stax.parallel`.
Here’s an example:
```python
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

# Apply network to dummy inputs
inputs = jnp.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```
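The following sketch shows how `(init_fun, apply_fun)` pairs could be written and composed in series by hand. The names `dense`, `relu` and `serial` are hypothetical stand-ins and not stax's actual source. They follow the signatures used in the example above: `init_fun(rng, input_shape)` returns `(output_shape, params)`, and `apply_fun(params, inputs)` returns outputs.

```python
import jax.numpy as jnp
from jax import random

def dense(out_dim):
  """Hypothetical Dense-like layer written as an (init_fun, apply_fun) pair."""
  def init_fun(rng, input_shape):
    # Parameters depend only on the last input dimension, so a leading
    # batch dimension of -1 is simply carried through to the output shape.
    w_key, _ = random.split(rng)
    in_dim = input_shape[-1]
    W = random.normal(w_key, (in_dim, out_dim)) * 0.01
    b = jnp.zeros((out_dim,))
    return input_shape[:-1] + (out_dim,), (W, b)

  def apply_fun(params, inputs):
    W, b = params
    return jnp.dot(inputs, W) + b

  return init_fun, apply_fun

# A parameterless layer: init_fun keeps the shape and returns empty params.
relu = (lambda rng, input_shape: (input_shape, ()),
        lambda params, inputs: jnp.maximum(inputs, 0.0))

def serial(*layers):
  """Hypothetical series combinator threading shapes and inputs through layers."""
  init_funs, apply_funs = zip(*layers)

  def init_fun(rng, input_shape):
    params = []
    for init in init_funs:
      rng, layer_rng = random.split(rng)  # fresh key for each layer
      input_shape, layer_params = init(layer_rng, input_shape)
      params.append(layer_params)
    return input_shape, params

  def apply_fun(params, inputs):
    for apply_layer, layer_params in zip(apply_funs, params):
      inputs = apply_layer(layer_params, inputs)
    return inputs

  return init_fun, apply_fun

init_fun, apply_fun = serial(dense(16), relu, dense(3))
out_shape, params = init_fun(random.PRNGKey(0), (-1, 8))
y = apply_fun(params, jnp.ones((5, 8)))
```

In this sketch, `init_fun` only propagates shapes. A `-1` batch dimension therefore passes through untouched, and the same `params` work for any batch size at apply time.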
First-order optimization
JAX has a minimal optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an `(init_fun, update_fun, get_params)` triple of functions:

- The `init_fun` is used to initialize the optimizer state, which could include things like momentum variables.
- The `update_fun` accepts a step index, a gradient and an optimizer state, and produces a new optimizer state.
- The `get_params` function extracts the current iterate (i.e., the current parameters) from the optimizer state.

The parameters being optimized can be ndarrays or arbitrarily nested list/tuple/dict structures, so you can store your parameters however you’d like.
Here’s an example, using `jit` to compile the whole update end-to-end:
```python
from jax.experimental import optimizers
from jax import jit, grad

# Continues the Stax example above: reuses jnp, net_apply and net_params.

# Define a simple squared-error loss
def loss(params, batch):
  inputs, targets = batch
  predictions = net_apply(params, inputs)
  return jnp.sum((predictions - targets)**2)

# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)

# Define a compiled update step
@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

# Dummy input data stream
data_generator = ((jnp.zeros((128, 28, 28, 1)), jnp.zeros((128, 10)))
                  for _ in range(10))

# Optimize parameters in a loop
opt_state = opt_init(net_params)
for i in range(10):
  opt_state = step(i, opt_state, next(data_generator))
net_params = get_params(opt_state)
```
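The following is a minimal sketch of what an `(init_fun, update_fun, get_params)` triple might contain: a momentum optimizer over arbitrary pytrees. It uses `jax.tree_util.tree_map`, so nested dicts and lists are handled uniformly. The function name `momentum` and its internals are hypothetical. We assume the classic heavy-ball update (`velocity = mass * velocity + g`, then `params -= step_size * velocity`). The library's own implementation may differ in details such as step-size schedules.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_map

def momentum(step_size, mass):
  """Hypothetical (init_fun, update_fun, get_params) triple for heavy-ball momentum."""
  def init_fun(params):
    # State pairs the current iterate with a zero velocity of the same structure.
    return params, tree_map(jnp.zeros_like, params)

  def update_fun(i, g, state):
    # `i` is unused here; a real optimizer could use it for a step-size schedule.
    params, velocity = state
    velocity = tree_map(lambda v, g_leaf: mass * v + g_leaf, velocity, g)
    params = tree_map(lambda p, v: p - step_size * v, params, velocity)
    return params, velocity

  def get_params(state):
    return state[0]

  return init_fun, update_fun, get_params

# Parameters stored as a nested dict/list structure.
params = {"w": jnp.zeros(2), "b": [jnp.array(0.0)]}

def loss(p):
  return jnp.sum((p["w"] - 3.0) ** 2) + (p["b"][0] + 1.0) ** 2

opt_init, opt_update, get_params = momentum(step_size=0.1, mass=0.9)
state = opt_init(params)
for i in range(300):
  state = opt_update(i, grad(loss)(get_params(state)), state)
final = get_params(state)
```

The optimizer never inspects the parameter structure directly; it only maps over leaves. This is what lets the parameters be stored "however you’d like".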