rocm_jax/jax/experimental
George Necula b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
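
The index normalization mentioned above can be sketched on concrete shapes. This is only an illustration of the idea: `normalize_index` is a hypothetical helper, not a function from JAX or jax2tf, and it uses an ordinary integer dimension size rather than a dimension polynomial.

```python
import jax.numpy as jnp

def normalize_index(i, dim_size):
    # Hypothetical helper: compare each index with 0 and add the
    # dimension size to the negative ones, mapping them into [0, dim_size).
    i = jnp.asarray(i)
    return jnp.where(i < 0, i + dim_size, i)

x = jnp.arange(10)
# The dimension size x.shape[0] is used as a value in the computation.
idx = normalize_index(jnp.array([-1, 0, -10, 3]), x.shape[0])
```

Here the dimension size enters the arithmetic itself, which is the situation the change is meant to support when the size is symbolic.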

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:

```
def dim_as_value(d):
   return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
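
As a rough check of that equivalence on concrete (non-polymorphic) shapes, the sketch below defines a local stand-in and uses it in the `average` function from above. The local `dim_as_value` here is an assumption for illustration only; it is not `core.dim_as_value`, and it does not exercise dimension polynomials.

```python
import jax.numpy as jnp

def dim_as_value(d):
    # Local stand-in: summing d ones yields d as an array value.
    return jnp.sum(jnp.broadcast_to(jnp.array(1), (d,)))

def average(x):
    # Divide by the leading dimension size, used as a value.
    return jnp.sum(x, axis=0) / dim_as_value(x.shape[0])

x = jnp.arange(12.0).reshape(4, 3)
avg = average(x)
```

On concrete shapes this should agree with `jnp.mean(x, axis=0)`.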

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll out the solution
we need to make `core.dim_as_value` a public API and teach
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00

Mini-libraries

JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as examples for how to build such libraries using JAX. Each one is fewer than 300 source lines of code, so take a look inside and adapt them as you need!

👉 Note: each mini-library is meant to be an inspiration, but not a prescription.

To serve that purpose, it is best to keep their code samples minimal, so we generally will not merge PRs adding new features. Instead, please send your lovely pull requests and design ideas to more fully-featured libraries like Haiku, Flax, or Trax.

Neural-net building with Stax

Stax is a functional neural network building library. The basic idea is that a single layer or an entire network can be modeled as an (init_fun, apply_fun) pair. The init_fun is used to initialize network parameters and the apply_fun takes parameters and inputs to produce outputs. There are constructor functions for common basic pairs, like Conv and Relu, and these pairs can be composed in series using stax.serial or in parallel using stax.parallel.

Here's an example:

```
import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

# Apply network to dummy inputs
inputs = jnp.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```
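
To make the (init_fun, apply_fun) idea concrete, here is a minimal hand-written sketch of a layer pair and a serial combinator in plain JAX. The `Dense`, `Relu`, and `serial` definitions below are simplified stand-ins written for illustration; they are assumptions about how such pairs could look, not the Stax source, and the 0.01 weight scale is arbitrary.

```python
import jax.numpy as jnp
from jax import random

def Dense(out_dim):
    """Illustrative layer constructor returning an (init_fun, apply_fun) pair."""
    def init_fun(rng, input_shape):
        k_w, _ = random.split(rng)
        in_dim = input_shape[-1]
        W = 0.01 * random.normal(k_w, (in_dim, out_dim))
        b = jnp.zeros((out_dim,))
        output_shape = input_shape[:-1] + (out_dim,)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return inputs @ W + b

    return init_fun, apply_fun

# A parameter-free layer is also just a pair; its params are an empty tuple.
Relu = (lambda rng, input_shape: (input_shape, ()),
        lambda params, inputs, **kwargs: jnp.maximum(inputs, 0))

def serial(*layers):
    """Illustrative combinator: compose pairs in sequence into one pair."""
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, layer_params = init(layer_rng, input_shape)
            params.append(layer_params)
        return input_shape, params

    def apply_fun(params, inputs, **kwargs):
        for apply, layer_params in zip(apply_funs, params):
            inputs = apply(layer_params, inputs, **kwargs)
        return inputs

    return init_fun, apply_fun

mlp_init, mlp_apply = serial(Dense(16), Relu, Dense(4))
mlp_out_shape, mlp_params = mlp_init(random.PRNGKey(0), (-1, 8))
mlp_outputs = mlp_apply(mlp_params, jnp.ones((5, 8)))
```

Because a composed network is itself an (init_fun, apply_fun) pair, it can be nested inside another `serial` call like any single layer.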

First-order optimization

JAX has a minimal optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions. The init_fun is used to initialize the optimizer state, which could include things like momentum variables, and the update_fun accepts a gradient and an optimizer state to produce a new optimizer state. The get_params function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you'd like.

Here's an example, using jit to compile the whole update end-to-end:

```
import jax.numpy as jnp
from jax.experimental import optimizers
from jax import jit, grad

# Reuses net_apply and net_params from the Stax example above

# Define a simple squared-error loss
def loss(params, batch):
  inputs, targets = batch
  predictions = net_apply(params, inputs)
  return jnp.sum((predictions - targets)**2)

# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)

# Define a compiled update step
@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

# Dummy input data stream
data_generator = ((jnp.zeros((128, 28, 28, 1)), jnp.zeros((128, 10)))
                  for _ in range(10))

# Optimize parameters in a loop
opt_state = opt_init(net_params)
for i in range(10):
  opt_state = step(i, opt_state, next(data_generator))
net_params = get_params(opt_state)
```
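
To illustrate the (init_fun, update_fun, get_params) triple itself, here is a minimal hand-written momentum optimizer in plain JAX. `momentum_sketch` is a hypothetical name, and the state layout (a plain `(params, velocity)` tuple) is an assumption chosen for readability; the library's own optimizers may represent their state differently. The toy loss and hyperparameters are likewise only for demonstration.

```python
import jax.numpy as jnp
from jax import grad, jit
from jax.tree_util import tree_map

def momentum_sketch(step_size, mass):
    """Illustrative optimizer returning (init_fun, update_fun, get_params)."""
    def init_fun(params):
        # Optimizer state: the parameters plus one velocity per parameter.
        velocity = tree_map(jnp.zeros_like, params)
        return params, velocity

    def update_fun(i, grads, opt_state):
        # i is the step counter; unused here since step_size is constant.
        params, velocity = opt_state
        velocity = tree_map(lambda v, g: mass * v + g, velocity, grads)
        params = tree_map(lambda p, v: p - step_size * v, params, velocity)
        return params, velocity

    def get_params(opt_state):
        params, _ = opt_state
        return params

    return init_fun, update_fun, get_params

# Toy loss with its minimum at w = 3, stored in a dict of parameters
def toy_loss(params):
    return jnp.sum((params["w"] - 3.0) ** 2)

toy_init, toy_update, toy_get_params = momentum_sketch(step_size=0.1, mass=0.9)

@jit
def toy_step(i, opt_state):
    g = grad(toy_loss)(toy_get_params(opt_state))
    return toy_update(i, g, opt_state)

toy_state = toy_init({"w": jnp.zeros(2)})
for i in range(300):
    toy_state = toy_step(i, toy_state)
toy_params = toy_get_params(toy_state)
```

Using `tree_map` for every update lets the same three functions handle a single array or a nested list/tuple/dict of arrays without changes.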