3569 Commits

Author SHA1 Message Date
Peter Hawkins
3c9ae5e221
Add jax.scipy.stats.logistic to documentation. (#2149) 2020-02-03 12:44:57 -05:00
Peter Hawkins
0b1d2fc3d1
Avoid accidental type promotion in gamma sampler gradient. (#2150)
Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.
2020-02-03 12:44:46 -05:00
Stephan Hoyer
0644f5c561
Better batching rule for triangular_solve (#2138)
* Better batching rule for triangular_solve

Now, if only the right-hand-side argument `b` is batched, we leverage
triangular solve's builtin batching for handling multiple right-hand-side
vectors.

This makes the performance of `vmap` over only the second argument of linear
solves equivalent to relying on builtin batching::

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500) + 0.1 * np.eye(500)
    b_mat = jax.device_put(rs.randn(500, 10))
    solve1 = jax.jit(np.linalg.solve)
    solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1))

Before::

    In [6]: %timeit jax.device_get(solve1(a, b_mat))
    3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    # 8x slower :(
    In [9]: %timeit jax.device_get(solve2(a, b_mat))
    23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Now::

    In [2]: %timeit jax.device_get(solve1(a, b_mat))
    3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

    # same speed :)
    In [3]: %timeit jax.device_get(solve2(a, b_mat))
    3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

* Test failures

* Check b.ndim == 2 in triangular solve shape rule
2020-02-03 09:27:03 -08:00
Roman Novak
1022573b26
Make stax pooling layers accept spec=None (#2145)
Currently, pooling layers have a default channel-last spec that is explicitly 2D. This change makes that default work for arbitrary input dimensionality.
2020-02-03 10:31:12 -05:00
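A minimal sketch of the behavior described above, assuming the `jax.experimental.stax` module of this era; the layer choice and input shape are illustrative, not from the commit::

    import jax
    import jax.numpy as jnp
    from jax.experimental import stax

    # With spec=None the pooling layer infers a channel-last layout from the
    # input rank, so the same default works for any dimensionality.
    init_fn, apply_fn = stax.MaxPool((2, 2), spec=None)
    out_shape, params = init_fn(jax.random.PRNGKey(0), (1, 8, 8, 3))
    y = apply_fn(params, jnp.ones((1, 8, 8, 3)))
    print(y.shape)  # pooled over the two spatial dimensions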
Colin
d6489103f7
Bump cell execution timeout (#2147)
Looking at the recent [doc failures](https://readthedocs.org/projects/jax/builds/), a few are due to 

- Cell timeouts (which this tries to fix),
- Execution timeout (readthedocs gives 900 seconds to build, total -- most of the time for jax is spent executing the notebooks),
- Other somewhat random/inscrutable errors (and I could imagine a world in which one of the timeouts ends up triggering an inscrutable error in the execution).
2020-02-03 10:15:19 -05:00
Peter Hawkins
fe041c7590 Set minimum Bazel version to 1.2.1. 2020-02-03 10:13:51 -05:00
Ruizhe Zhao
8c7fc3919d
Upgrade bazel from 0.29.1 to 1.2.1 (#2137) 2020-02-03 10:12:40 -05:00
Matthew Johnson
ae1d6b875f
fix remat with nontrivial env (#2136)
fixes #2030
2020-01-31 23:47:30 -08:00
Skye Wanderman-Milne
efbdaf66bf Adjust scipy_stats_test.py tolerance. 2020-01-31 11:19:55 -08:00
Peter Hawkins
91cd20b173
Update documentation and changelog to mention DLPack and array interface support. (#2134) 2020-01-31 11:15:04 -05:00
Peter Hawkins
843e22dd17
Support __cuda_array_interface__ on JAX DeviceArrays. (#2133)
Allows exporting GPU device-resident arrays to other libraries, e.g., CuPy.
2020-01-31 10:09:40 -05:00
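A rough sketch of the interop this enables, assuming a GPU backend and CuPy installed; CuPy consumes `__cuda_array_interface__` via `cupy.asarray`::

    import jax.numpy as jnp
    import cupy

    x = jnp.arange(4.0)   # device-resident under a GPU backend
    y = cupy.asarray(x)   # imports the array via __cuda_array_interface__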
Tuan Nguyen
4c30c0285c
Implement scipy.stats.logistic (#1993) 2020-01-30 20:19:01 -05:00
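A quick sketch of the new distribution functions (standard logistic), assuming the module exposes `pdf` and `cdf`::

    from jax.scipy.stats import logistic

    print(logistic.cdf(0.0))  # 0.5
    print(logistic.pdf(0.0))  # 0.25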
Peter Hawkins
0103929930
Revert "Use lax.erf_inv to implement ndtri. (#2122)" (#2128)
This reverts commit bbcbe23c1ee52cf76542f3a60f8344832a0dd05f.

This change appears to cause test failures in TF probability's JAX backend.
2020-01-30 19:19:41 -05:00
Roy Frostig
664a4e123d
VJP of cond, via partial eval + transpose (#2091)
VJP (grad) of lax.cond, via partial eval + transpose


Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-30 15:03:00 -08:00
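A minimal sketch of what this enables, using the `lax.cond` signature of this era, `cond(pred, true_operand, true_fun, false_operand, false_fun)`::

    import jax
    from jax import lax

    def f(x):
      # 3 * x**2 on the true branch, -x on the false branch
      return lax.cond(x > 0., x, lambda x: 3. * x ** 2, x, lambda x: -x)

    print(jax.grad(f)(2.0))   # 12.0
    print(jax.grad(f)(-2.0))  # -1.0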
Peter Hawkins
5362ea1d3d
Use make_tuple instead of from_pyval to build types. (#2127) 2020-01-30 16:17:08 -05:00
Daniel Johnson
60bc0c0d77
Batching rule for custom_linear_solve (#2099)
* Add batching rule for custom_linear_solve

The custom_linear_solve primitive is batched by batching each of the
jaxprs stored in the primitive. Conceptually, this corresponds to
transforming the initial linear system solve into an implicit "block
diagonal" solve, where matvec/vecmat apply a linear operation to each
part of the input, and solve/transpose_solve solve each block
separately.

Note that the batching of the input and output must agree between all
four jaxprs, since the custom_linear_solve JVP and transpose rules
assume the shapes all match. In particular, the JVP passes the output
of solve into (the JVP of) matvec and then that output back into solve,
and the transpose rule causes the (cotangents for the) output of solve
to be fed back into solve_transpose and vecmat. To ensure consistency
we can do a fixed-point iteration to figure out whether each component
of x and b is batched or not.

* Add support for batched solves without a transpose

If there is no transpose solve, we don't need to batch the transposed
versions of the jaxprs.

* Add pytree test for custom linear solve

Custom linear solve supports solves that act on pytrees, not just single
arrays. This commit adds a test for a linear operator and solver that
operate on Python lists of scalars instead of vectors, and confirms that
transformations work correctly. The batching behavior has been chosen
to make sure it requires a few iterations to find the fixed point of
which elements have a batch axis.
2020-01-30 10:02:58 -08:00
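A rough sketch of the batched use this supports; the wrapper below is illustrative, not from the commit, and simply solves a @ x = b through `lax.custom_linear_solve`::

    import jax
    import jax.numpy as jnp
    from jax import lax

    def solve_with(a):
      matvec = lambda x: a @ x
      solve = lambda _, b: jnp.linalg.solve(a, b)       # ignores the matvec argument
      tr_solve = lambda _, b: jnp.linalg.solve(a.T, b)
      return lambda b: lax.custom_linear_solve(matvec, b, solve, tr_solve)

    a = jnp.eye(3) + 0.1 * jnp.ones((3, 3))
    b_batch = jnp.arange(12.0).reshape(4, 3)
    # vmap over the right-hand side only; the new rule batches the stored jaxprs.
    x_batch = jax.vmap(solve_with(a))(b_batch)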
Peter Hawkins
511d33983a
Initial implementation of DLPack support. (#2123)
* Initial implementation of DLPack support.

Unfortunately there are still a few bugs in the jaxlib DLPack support, so this code won't be ready to use until jaxlib 0.1.39.

* Fix test failures.

* Update XLA.

Fix failing torch test.
2020-01-29 23:19:14 -05:00
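A minimal sketch of the round trip, assuming the `jax.dlpack` helpers added here and a jaxlib of at least 0.1.39 as noted above::

    import jax.dlpack
    import jax.numpy as jnp

    x = jnp.arange(8.0)
    capsule = jax.dlpack.to_dlpack(x)     # export as a DLPack capsule
    y = jax.dlpack.from_dlpack(capsule)   # import back as a DeviceArray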
Stephan Hoyer
be2704e425
Ensure ShapedArray.shape is always a tuple of builtins integers (#2039)
* Ensure ShapedArray.shape is always a tuple of builtins integers

Currently, it can sometimes include elements of type int64, e.g.,

    In [1]: import jax.numpy as jnp

    In [2]: x = jnp.arange(3) + 1

    In [3]: x.shape  # looks fine at first glance
    Out[3]: (3,)

    In [4]: type(x.shape[0])  # yikes!
    Out[4]: numpy.int64

This confirms my hypothesis that NumPy's scalar types are the root of all evil.

* Allow Poly in shapes

* Simple shape coercion in ShapedArray

* cleaner
2020-01-29 14:24:11 -08:00
Peter Hawkins
1c134f8a6d
Rename Tracer.trace to Tracer._trace. (#2114)
Makes the .trace() method work on arrays.
2020-01-29 16:23:27 -05:00
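A small sketch of the now-working NumPy-style method on an array::

    import jax.numpy as jnp

    print(jnp.eye(3).trace())  # 3.0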
Matthew Johnson
96102dc727
simplify cond by removing consts (#2102)
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.

However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.

Co-authored-by: Roy Frostig <frostig@google.com>
2020-01-29 13:17:39 -08:00
Peter Hawkins
bbcbe23c1e
Use lax.erf_inv to implement ndtri. (#2122) 2020-01-29 16:11:30 -05:00
Chris Jones
a503395692
Add num_partitions parameter to get_compile_options. (#2052) 2020-01-29 11:35:48 -08:00
Skye Wanderman-Milne
633aa00392 Update README to drop CUDA 9.0 and add CUDA 10.2 2020-01-29 11:23:27 -08:00
Peter Hawkins
991324f8df
Increase minimum jaxlib version to 0.1.38. (#2120) 2020-01-29 14:16:58 -05:00
Skye Wanderman-Milne
09d2421f28 Update jaxlib version in README to 0.1.38 2020-01-29 10:48:20 -08:00
Skye Wanderman-Milne
409d057f76
Build CUDA 10.2 jaxlibs. (#2121)
Also adds an install_cuda.sh script that sets the appropriate NCCL and cuDNN versions.
2020-01-29 10:47:17 -08:00
Peter Hawkins
48928a18c4
Fix test failures in incomplete gamma functions with Jaxlib 0.1.38. (#2118) 2020-01-29 13:24:45 -05:00
Tom Hennigan
ab13cf3f65
Add lax.pmean(x, axis_name). (#2081) 2020-01-29 10:10:48 -08:00
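A quick sketch of the new collective, run under `pmap` over however many devices are available (with a single device the mean is just the value itself)::

    import jax
    import jax.numpy as jnp
    from jax import lax

    n = jax.device_count()
    x = jnp.arange(float(n))
    # Each participant contributes its value; pmean averages across axis 'i'.
    print(jax.pmap(lambda v: lax.pmean(v, 'i'), axis_name='i')(x))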
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
4803a75c3b
Implement np.block. (#2106)
Rename np.removechars to _removechars; it should never have been public.
2020-01-29 11:55:53 -05:00
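A small sketch of the newly added `np.block`, mirroring NumPy's behavior::

    import jax.numpy as jnp

    A = jnp.eye(2)
    B = jnp.zeros((2, 2))
    M = jnp.block([[A, B],
                   [B, A]])   # assembles a 4x4 block matrix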
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
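A quick sketch of the new regularized incomplete gamma functions; for positive arguments the two should sum to one::

    from jax.scipy.special import gammainc, gammaincc

    a, x = 2.0, 1.5
    print(gammainc(a, x) + gammaincc(a, x))  # ~1.0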
Peter Hawkins
cfef568dd6
Implement jax.scipy.linalg.block_diag. (#2113) 2020-01-29 11:24:40 -05:00
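A small sketch of the new helper, mirroring `scipy.linalg.block_diag`::

    import jax.numpy as jnp
    from jax.scipy.linalg import block_diag

    M = block_diag(jnp.eye(2), 3.0 * jnp.eye(1))  # 3x3 block-diagonal matrix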
Peter Hawkins
0904e5ff74
Fix implementation of cumsum/cumprod for boolean inputs. (#2112)
Check for number inputs in the reduce_window_sum dtype rule.
2020-01-29 10:51:39 -05:00
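A quick check of the fixed behavior (boolean inputs now produce numeric cumulative results)::

    import jax.numpy as jnp

    print(jnp.cumsum(jnp.array([True, False, True])))  # [1 1 2]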
Peter Hawkins
04befac4f6
Fix error case in tensordot. (#2111) 2020-01-29 10:14:36 -05:00
Peter Hawkins
102ce6f0ac
Merge pull request #2100 from hawkinsp/devices
Use Device hash and equality instead of using a (class, id) pair.
2020-01-29 09:11:14 -05:00
Tom Hennigan
4e575e1492
Support trees in lax parallel operations. (#1953)
It is relatively common to apply collective operations to trees. For example in
sync distributed training it is typical to sum all gradients across replicas
`grads = jax.tree_map(partial(lax.psum, axis_name='i'), grads)`. We can make
this a little more convenient by making lax parallel ops support trees directly:
`grads = lax.psum(grads, 'i')`.

There is room for improvement in this change. In some (all?) cases we should
just pass a tuple of values to XLA (rather than binding the primitive n times,
bind once with a tuple of n values); however, this produced strange values when
combined with pmap, and a fix was not obvious. This is something we can follow
up on without users having to change their code.
2020-01-28 19:04:59 -08:00
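A minimal sketch of the tree support described above, applying `psum` directly to a dict of per-device gradients (shapes here are illustrative)::

    import jax
    import jax.numpy as jnp
    from jax import lax

    n = jax.device_count()
    grads = {'w': jnp.ones((n, 2)), 'b': jnp.ones((n,))}
    # No tree_map needed: psum now accepts the whole pytree.
    summed = jax.pmap(lambda g: lax.psum(g, 'i'), axis_name='i')(grads)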
Peter Hawkins
7b7c89db98
Merge pull request #2086 from romanngg/patch-6
Make the reverse operator work on empty list of dimensions
2020-01-28 21:51:53 -05:00
James Bradbury
1a5d9c531a
clear compilation cache before metadata tests (#2103) 2020-01-28 18:45:45 -08:00
Peter Hawkins
9f7f161c5f Incorporate review comments. 2020-01-28 21:42:45 -05:00
Matthew Johnson
d46e82d0ab
tweak readme announcement text again 2020-01-28 18:16:04 -08:00
Matthew Johnson
71811be3b9
tweak top-line announcement text in readme 2020-01-28 18:15:16 -08:00
Skye Wanderman-Milne
6aaf257d8a Update WORKSPACE jaxlib-v0.1.38 2020-01-28 18:04:52 -08:00
Matthew Johnson
1afcac70df
tweak readme not to have bad line wrap 2020-01-28 16:41:21 -08:00
Peter Hawkins
35810c9dcd
Merge pull request #2101 from hawkinsp/tolist
Implement ndarray.tolist() on DeviceArray.
2020-01-28 17:03:35 -05:00
Peter Hawkins
126ae7fccf Implement ndarray.tolist() on DeviceArray. 2020-01-28 15:58:02 -05:00
Daniel Johnson
b68d8b5c4f Clarify instructions for building from source. (#2093)
Adds additional subsections of the `Building from source` documentation
page to make it more obvious that you can install `jaxlib` from pip
when doing Python-only development.
2020-01-28 12:48:37 -08:00
Peter Hawkins
b54c18efb4 Use Device hash and equality instead of using a (class, id) pair.
We couldn't figure out why we did it this way in the first place, and all the tests we have pass.
2020-01-28 15:45:40 -05:00
Peter Hawkins
58f949f316
Merge pull request #2098 from hawkinsp/jaxlib
Update Jaxlib docker build.
2020-01-28 11:33:58 -05:00
Peter Hawkins
55f2d3be27 Update Jaxlib docker build.
* work around https://github.com/bazelbuild/bazel/issues/9254 by setting BAZEL_LINKLIBS=-lstdc++
* drop CUDA 9.0 support, since we use a batched kernel only present in CUDA 9.2 or later.
* drop Python 2.7 support.
2020-01-28 11:17:21 -05:00
Peter Hawkins
9a0338d6aa
Update README.md and CHANGELOG.md. (#2096) 2020-01-28 10:01:17 -05:00