* Better batching rule for triangular_solve
Now, if only the right-hand-side argument `b` is batched, we leverage
triangular solve's builtin batching to handle multiple right-hand-side
vectors.
This makes the performance of `vmap` over only the second argument of linear
solves equivalent to relying on builtin batching::

    import jax
    import jax.numpy as np
    import numpy as onp

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500) + 0.1 * np.eye(500)
    b_mat = jax.device_put(rs.randn(500, 10))
    solve1 = jax.jit(np.linalg.solve)
    solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1))
Before::

    In [6]: %timeit jax.device_get(solve1(a, b_mat))
    3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    # ~6x slower :(
    In [9]: %timeit jax.device_get(solve2(a, b_mat))
    23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Now::

    In [2]: %timeit jax.device_get(solve1(a, b_mat))
    3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

    # same speed :)
    In [3]: %timeit jax.device_get(solve2(a, b_mat))
    3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
* Fix test failures
* Check b.ndim == 2 in triangular solve shape rule
Currently pooling layers have a default channel-last spec that is explicitly 2D. This change will make this default work for arbitrary input dimensionality.
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to
- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is spent executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).
* Add batching rule for custom_linear_solve
The custom_linear_solve primitive is batched by batching each of the
jaxprs stored in the primitive. Conceptually, this corresponds to
transforming the initial linear system solve into an implicit "block
diagonal" solve, where matvec/vecmat apply a linear operation to each
part of the input, and solve/transpose_solve solve each block
separately.
Note that the batching of the input and output must agree between all
four jaxprs, since the custom_linear_solve JVP and transpose rules
assume the shapes all match. In particular, the JVP passes the output
of solve into (the JVP of) matvec and then that output back into solve,
and the transpose rule causes the (cotangents for the) output of solve
to be fed back into solve_transpose and vecmat. To ensure consistency,
we do a fixed-point iteration to figure out whether each component
of x and b is batched.
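As a rough sketch of the kind of program this enables (the names and the
matrix are illustrative, not taken from the commit or its tests), batching a
custom_linear_solve over a stack of right-hand sides looks like::

    import jax
    import jax.numpy as jnp
    from jax import lax

    A = jnp.array([[4.0, 1.0],
                   [1.0, 3.0]])

    def matvec(x):
        return A @ x

    def solve(matvec_fn, b):
        # Any solver works here; a dense solve keeps the sketch small.
        return jnp.linalg.solve(A, b)

    def linear_solve(b):
        return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

    b_batch = jnp.ones((5, 2))                 # five right-hand sides
    x_batch = jax.vmap(linear_solve)(b_batch)  # exercises the new batching rule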
* Add support for batched solves without a transpose
If there is no transpose solve, we don't need to batch the transposed
versions of the jaxprs.
* Add pytree test for custom linear solve
Custom linear solve supports solves that act on pytrees, not just single
arrays. This commit adds a test for a linear operator and solver that
operate on Python lists of scalars instead of vectors, and confirms that
transformations work correctly. The batching behavior has been chosen
to make sure it requires a few iterations to find the fixed point of
which elements have a batch axis.
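For a flavor of what such a pytree solve looks like (a hand-written sketch,
not the commit's test code), a linear operator and solver acting on a Python
list of two scalars might be::

    from jax import lax

    # The implicit matrix is [[2, 1], [1, 3]], acting on a list [a, b].
    def matvec(x):
        a, b = x
        return [2.0 * a + 1.0 * b, 1.0 * a + 3.0 * b]

    def solve(matvec_fn, y):
        # Hand-coded inverse of the 2x2 system, still returning a list.
        p, q = y
        det = 2.0 * 3.0 - 1.0 * 1.0
        return [(3.0 * p - 1.0 * q) / det, (2.0 * q - 1.0 * p) / det]

    x = lax.custom_linear_solve(matvec, [1.0, 2.0], solve)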
* Initial implementation of DLPack support.
Unfortunately there are still a few bugs in the jaxlib DLPack support, so this code won't be ready to use until jaxlib 0.1.39.
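As a rough usage sketch (assuming the `jax.dlpack` module, with function names
as in current JAX rather than necessarily this commit), a round trip through a
DLPack capsule looks like::

    import jax.dlpack
    import jax.numpy as jnp

    x = jnp.arange(10.0)
    capsule = jax.dlpack.to_dlpack(x)    # export a JAX array as a DLPack capsule
    y = jax.dlpack.from_dlpack(capsule)  # import it back, zero-copy where possible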
* Fix test failures.
* Update XLA.
Fix failing torch test.
* Ensure ShapedArray.shape is always a tuple of builtin integers
Currently, it can sometimes include elements of type numpy.int64, e.g.::

    In [1]: import jax.numpy as jnp
    In [2]: x = jnp.arange(3) + 1
    In [3]: x.shape  # looks fine at first glance
    Out[3]: (3,)
    In [4]: type(x.shape[0])  # yikes!
    Out[4]: numpy.int64
This confirms my hypothesis that NumPy's scalar types are the root of all evil.
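A minimal sketch of the kind of coercion described above (a hypothetical
helper, not the commit's actual code)::

    import numpy as onp

    def canonicalize_shape(shape):
        # Coerce every dimension to a builtin int so ShapedArray.shape never
        # carries numpy scalar types.
        return tuple(int(d) for d in shape)

    shape = canonicalize_shape((onp.int64(3), 4))
    assert all(type(d) is int for d in shape)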
* Allow Poly in shapes
* Simple shape coercion in ShapedArray
* cleaner
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.
However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.
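For concreteness, a closed-over constant in a cond looks like the sketch below
(a hypothetical example written against the current lax.cond signature); with
this change, `c` is handled just like the explicit operand `x` rather than
being tracked as a distinguished constant::

    import jax
    import jax.numpy as jnp
    from jax import lax

    c = jnp.float32(2.0)  # closed-over constant lifted into the cond's inputs

    def f(pred, x):
        return lax.cond(pred,
                        lambda x: x + c,   # true branch
                        lambda x: x * c,   # false branch
                        x)

    print(jax.make_jaxpr(f)(True, jnp.float32(1.0)))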
Co-authored-by: Roy Frostig <frostig@google.com>
It is relatively common to apply collective operations to trees. For example, in
sync distributed training it is typical to sum all gradients across replicas:
`grads = jax.tree_map(partial(lax.psum, axis_name='i'), grads)`. We can make
this a little more convenient by making lax parallel ops support trees directly:
`grads = lax.psum(grads, 'i')`.
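A small sketch of what the tree support enables (names are illustrative;
assumes at least one local device)::

    import jax
    import jax.numpy as jnp
    from jax import lax

    n = jax.local_device_count()
    grads = {'w': jnp.ones((n, 3)), 'b': jnp.ones((n,))}  # leading device axis

    # psum accepts the whole pytree of gradients directly.
    summed = jax.pmap(lambda g: lax.psum(g, 'i'), axis_name='i')(grads)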
There is room for improvement in this change. We should in some (all?) cases
pass a tuple of values to XLA, binding the primitive once with a tuple of n
values rather than binding it n times; however, this produced strange values
when combined with pmap, and a fix was not obvious. This is something we can
follow up on without users having to change their code.
Adds additional subsections to the `Building from source` documentation page
to make it more obvious that you can install `jaxlib` from pip when doing
Python-only development.
* work around https://github.com/bazelbuild/bazel/issues/9254 by setting BAZEL_LINKLIBS=-lstdc++
* drop CUDA 9.0 support, since we use a batched kernel only present in CUDA 9.2 or later.
* drop Python 2.7 support.