121 Commits

Author SHA1 Message Date
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944 avoid Calls inside While/Cond
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
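A sketch of what the API surface described above might look like in use, on a hypothetical single-host CPU setup. The `Device` class and the device counts here are stand-ins for illustration, not the real JAX objects:

```python
# Hypothetical stand-ins sketching the shape of the new APIs on a
# single-host setup; these are NOT the real JAX implementations.

class Device:
    """Stand-in for a Device subclass instance returned by jax.devices()."""
    def __init__(self, id, platform):
        self.id = id
        self.platform = platform

def devices():
    """Sketch of jax.devices(): list the available Device instances."""
    return [Device(i, "cpu") for i in range(2)]

def host_id():
    """Sketch of jax.host_id(): currently always 0."""
    return 0

def local_device_count():
    """Sketch of jax.local_device_count(): equals the global device
    count on a single host."""
    return len(devices())

# The new `devices` argument would let a pmap run on a chosen subset:
first_two = devices()[:2]   # e.g. pmap(f, devices=first_two) in real JAX
```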
Matthew Johnson
0cc21c8d72 Merge branch 'master' into multibackend 2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5 Make op-by-op work with all jit-returned DeviceArrays. 2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013 WIP: experimental multibackend jit 2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
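The commit doesn't show the exact rule, but the underlying issue can be illustrated with NumPy: XLA's `rem` follows the sign of the dividend (like C), so a negative index stays negative instead of wrapping, whereas a select-based rule shifts negative indices up by the axis size:

```python
import numpy as np

n = 5
idx = np.array([-1, 0, 4])

# C-style remainder (what XLA's rem computes): the sign follows the
# dividend, so a negative index stays negative -- wrong for wraparound.
rem_style = np.fmod(idx, n)                    # [-1, 0, 4]

# Select-based wraparound: shift negative indices up by n instead.
select_style = np.where(idx < 0, idx + n, idx)  # [4, 0, 4]
```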
Peter Hawkins
932877dde6 Remove unnecessary reshape/concatenate in dynamic_slice_in_dim. 2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0 Fix Python 2 compatibility. 2019-08-15 13:14:41 -04:00
Peter Hawkins
e28e73b38f Address review comment. 2019-08-15 12:33:36 -04:00
Peter Hawkins
e57a5c42c5 Fix batching rule. 2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741 Fix batching rule. 2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
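A minimal NumPy model of the new calling convention, with one scalar start index per dimension rather than a single index array (this is an illustration of the API shape, not JAX's implementation):

```python
import numpy as np

def dynamic_slice(operand, start_indices, slice_sizes):
    """Toy model of lax.dynamic_slice in the new form: one scalar start
    index per dimension. Keeping indices separate means a constant index
    (like the 0 below) stays visible to the compiler even when the other
    indices are traced values."""
    slices = tuple(slice(int(s), int(s) + size)
                   for s, size in zip(start_indices, slice_sizes))
    return operand[slices]

x = np.arange(12).reshape(3, 4)
out = dynamic_slice(x, (1, 0), (2, 2))   # rows 1:3, cols 0:2
```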
Peter Hawkins
3e78a0e290 Keep ShapedArray avals on xla.DeviceArray values
Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
2019-08-12 10:03:04 -04:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
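A minimal thread-safe memoizer in the spirit of the `util.cache` decorator described above; the real implementation wrapped fastcache, while this sketch uses a plain dict guarded by a lock:

```python
import functools
import threading

def cache(fn):
    """Sketch of a thread-safe memoizing decorator: a plain dict plus a
    lock, so concurrent callers cannot race on the cache."""
    store = {}
    lock = threading.Lock()

    @functools.wraps(fn)
    def wrapper(*args):
        with lock:
            if args not in store:
                store[args] = fn(*args)
            return store[args]
    return wrapper

calls = []

@cache
def square(x):
    calls.append(x)   # record actual invocations to show memoization
    return x * x
```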
Peter Hawkins
51eb67f755 pmap optimization: Don't precompute size and ndim on DeviceArrays.
We don't even look at them most of the time, and they are in the critical path for running jit/pmap code.

Saves ~1-2ms on a pmap microbenchmark.
2019-08-06 10:29:25 -04:00
James Bradbury
a26963fe87 Merge pull request #1106 from google/jb/bool-reduction
fix jax.numpy reduction init_val for bools
2019-08-05 10:45:17 -07:00
Peter Hawkins
0ef05d7586 Cleanups to xla_bridge.py
Remove stringification of dtypes. The NumPy dtype handling bug has to do with types with different hashes comparing as equal. This does not happen between two np.dtype objects; it is sufficient to simply ensure we actually have an np.dtype rather than something dtype-like (e.g., a string or NumPy type object).
Remove xla_bridge.infeed_put, which is unused.
Remove xla_bridge.Shape (use xla_client.Shape instead).
Remove xla_bridge.dtype_to_etype_exact (use xla_client.dtype_to_etype instead).
Remove xla_bridge.device_put (inlined the definition into its callers)
Remove xla_bridge.make_tuple (inlined the definition into its callers).
2019-08-04 12:52:39 -04:00
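The dtype issue above can be demonstrated directly with NumPy: a dtype-like string compares equal to the dtype but generally hashes differently, so dict or cache lookups keyed on "the dtype" can silently miss. The `canonical` helper below is a hypothetical name sketching the fix of normalizing to `np.dtype` up front:

```python
import numpy as np

# A dtype-like string compares equal to the dtype, yet hash(d) and
# hash('float32') will (almost surely) differ -- hash-based lookups
# keyed on "the dtype" can therefore silently miss.
d = np.dtype('float32')
equal_but_not_same_type = (d == 'float32')

def canonical(dtype_like):
    """Hypothetical helper: normalize any dtype-like value to np.dtype."""
    return np.dtype(dtype_like)
```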
James Bradbury
d0c9f45349 fix jax.numpy reduction init_val for bools 2019-08-03 21:27:06 -07:00
fehiepsi
1b490fb5e0 Merge remote-tracking branch 'upstream/master' into sort 2019-08-01 12:39:53 -04:00
fehiepsi
e1ee87b559 add batching rule for lax.sort 2019-08-01 12:39:33 -04:00
Matthew Johnson
0600b738f4 fix symbolic zero handling in _pad_transpose
tested manually against example from @matthewdhoffman
2019-07-31 13:27:19 -07:00
Peter Hawkins
d0644d6a3a Remove old xla_data_pb2 compatibility shim. 2019-07-29 15:21:47 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
Peter Hawkins
2369d1fe61 Increase minimum Jaxlib version to 0.1.22.
Remove code that preserves backward compatibility with older jaxlib versions.
2019-07-23 21:45:41 -04:00
Peter Hawkins
1479ae9066 Add a common lax._canonicalize_shape method, use on methods that accept shapes in lax.
Explicitly convert shape entries to integers using the Python __index__() method.
Implement __index__ on DeviceArrays so shapes like (1, DeviceArray(2)) work.

Fixes bug where np.full accepted floating point shapes; __index__() errors for non-integer inputs, where int() would silently cast and drop information.
2019-07-23 16:19:02 -04:00
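The `__index__`-based conversion can be sketched with the standard library's `operator.index`, which raises for non-integral values instead of silently truncating the way `int()` would (function name here is a stand-in for the private `lax._canonicalize_shape`):

```python
import operator

def canonicalize_shape(shape):
    """Sketch of the idea: convert each shape entry via __index__, which
    raises TypeError for non-integral values instead of silently
    truncating the way int() would."""
    return tuple(operator.index(d) for d in shape)

# A float shape entry now raises instead of being silently truncated:
try:
    canonicalize_shape((1.5, 2))
    rejected = False
except TypeError:
    rejected = True
```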
Peter Hawkins
f64332b394 Remove assertions in scatter/dynamic_update_slice JVP rules that test whether index tangents are symbolically zero.
Since indices are integers, their tangents should be zero anyway, and symbolic zeros should always be treated as an optimization rather than a necessary precondition.
2019-07-23 14:18:07 -04:00
Peter Hawkins
0850318a83 Add support for mixing basic and advanced indexing in the same scatter operation. 2019-07-14 11:55:26 -04:00
Peter Hawkins
05ff396716 Add batching rule for reduce_window_p. Allows batching of np.cumprod. 2019-07-13 10:22:26 -04:00
Matthew Johnson
79668ae4ed fix reduce_window batching rule 2019-07-06 11:58:33 -07:00
Matthew Johnson
ddf7f69cad fix select broadcasting rule 2019-07-06 11:52:24 -07:00
Matthew Johnson
febad2d863 fix broadcast_in_dim batching rule 2019-07-06 11:47:50 -07:00
Matthew Johnson
ccb1760f49 add a lot of systematic vmap tests 2019-07-06 11:28:15 -07:00
Matthew Johnson
a5e86ae128 enable soft_pmap device persistence
Previously soft_pmap didn't allow for sharded device persistence because
it performs reshapes on the input and output of the underlying pmap
computation corresponding to splitting out and merging together the
hardware-mapped and software-mapped axes, respectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.

The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.

One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that when forced the former collects its shards
together using onp.stack while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).

ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collection, just like it does for
ShardedDeviceArrays.
2019-07-06 10:21:59 -07:00
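The stack-vs-concatenate distinction described above can be illustrated with plain NumPy shards:

```python
import numpy as np

# Per-device shards of a logically larger array.
shards = [np.arange(3), np.arange(3, 6)]

# A ShardedDeviceArray forces with stack: the shard axis stays a
# leading dimension of the result.
stacked = np.stack(shards)         # shape (2, 3)

# A ChunkedDeviceArray forces with concatenate: the shard axis is
# merged away, modeling the axis-merging reshape without eager
# data movement.
chunked = np.concatenate(shards)   # shape (6,)
```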
Matthew Johnson
db52d42597 also fix lax.complex jvp, enable test 2019-07-05 14:39:32 -07:00
Matthew Johnson
93841df822 fix lax.imag jvp and enable test, fixes #979 2019-07-05 14:32:04 -07:00
Peter Hawkins
a06ba06f97 Update comments. 2019-07-02 13:23:05 -04:00
Peter Hawkins
165df6204b Simplify reduce-precision logic.
Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window tests select_and_scatter_add, which is the part changed by this PR.
2019-07-02 11:34:49 -04:00
Peter Hawkins
40560d2c9a Refactor select_and_gather_add implementation to improve readability.
Change implementation to use ReducePrecision to perform half-word reductions.
2019-07-01 22:26:36 -04:00
Peter Hawkins
db369091a2 Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
Warn if we are forced to be imprecise.
2019-06-28 20:27:10 -04:00
Peter Hawkins
3e914e17b0 Improve documentation for precision. 2019-06-28 14:06:24 -04:00
Peter Hawkins
bca27fea8b Simplify precision specification: only allow a single precision for an entire operator. 2019-06-28 12:48:44 -04:00
Peter Hawkins
0af9da7662 Add precision option to lax dot and conv APIs.
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
2019-06-28 10:00:39 -04:00
Peter Hawkins
3b4521b1f6 Enable convolutions for non float32 types. 2019-06-27 17:17:49 -04:00
Peter Hawkins
990c2df123 Implement a pure Python LU decomposition that can be used on platforms where we do not otherwise have a better implementation.
Restructure xla.lower_fun and trace_unwrapped_to_jaxpr so the instantiate option can be passed to them, separately from any function arguments.
2019-06-27 14:50:29 -04:00