280 Commits

Author SHA1 Message Date
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
Brian Patton
4b693777aa
Ensure reps is a tuple (allows list or other iterable) 2019-08-15 11:26:25 -05:00
David Majnemer
079ded4062 Use lax.rem less often in remainder 2019-08-14 12:00:04 -07:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
James Bradbury
d0c9f45349 fix jax.numpy reduction init_val for bools 2019-08-03 21:27:06 -07:00
Matthew Johnson
3168006f4a fix np.var dtype bug 2019-08-02 11:26:17 -07:00
Matthew Johnson
1f3b4ae97e
Merge pull request #1091 from fehiepsi/tril
expose tril_indices, triu_indices similar to diag_indices
2019-08-01 20:58:22 -07:00
Matthew Johnson
fd98f957a9
Merge pull request #1088 from fehiepsi/median
Add numpy.median and support ddof for numpy.var
2019-08-01 20:57:28 -07:00
fehiepsi
7a5aecea31 expose tril_indices triu_indices 2019-08-01 17:35:36 -04:00
fehiepsi
45c5bd4fba support ddof for var 2019-08-01 16:20:08 -04:00
fehiepsi
98152d9d07 add numpy.median 2019-08-01 14:19:41 -04:00
Jamie Townsend
47f9eedb60
Correct jax.numpy.pad signature 2019-08-01 15:44:23 +01:00
Peter Hawkins
a350191331
Merge pull request #1074 from hawkinsp/pytree
Add C++ implementation of Pytree logic.
2019-07-30 20:50:09 -04:00
wyjw
4dcae5debf
Update lax_numpy.py 2019-07-29 22:56:30 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
wyjw
b89e5a7ac0
shape_c variable taken out 2019-07-29 13:11:43 -04:00
wyjw
5487c784d6
added shape check 2019-07-29 13:05:48 -04:00
wyjw
a06883d91f
Made changes based on review. 2019-07-29 11:53:40 -04:00
wyjw
a87627b57b
Revert "Corrcoef" 2019-07-29 11:24:05 -04:00
twnly
3b6edbbe2f corrcoef 2019-07-29 11:12:02 -04:00
twnly
d9b7c5fa39 made changes to corrcoef 2019-07-29 11:06:08 -04:00
twnly
c2f26d7afc here 2019-07-28 15:17:23 -04:00
Peter Hawkins
00fabfe1e4 Implement DeviceArray.__setitem__ with an error message pointing the user to jax.ops.index_update. 2019-07-21 16:47:03 -04:00
Peter Hawkins
6995a2a8d9
Merge pull request #1025 from hawkinsp/master
Merge scatter and gather indexing implementations.
2019-07-21 21:37:32 +01:00
Peter Hawkins
4eb1820ae2 Add documentation to JAX modules. 2019-07-21 15:55:47 -04:00
Peter Hawkins
eefd551767 Support older Numpy versions that don't have np.quantile.
Fix typo.
2019-07-20 08:44:04 +01:00
Peter Hawkins
f25b2f878b Merge scatter and gather indexing implementations. 2019-07-16 18:55:44 +01:00
Peter Hawkins
35729a692f Fix take_along_axis in x64 mode. 2019-07-13 09:59:19 -04:00
Peter Hawkins
76834c7600 Lower jax.numpy.take_along_axis directly to lax.gather().
This both allows us to avoid reshapes, and also allows us to avoid forming some unnecessarily large iota constants.
2019-07-12 21:43:07 -04:00
Matthew Johnson
a12161435f
Merge pull request #990 from superbobry/memoryview-as-array
Added support for creating arrays via the buffer interface
2019-07-11 21:30:57 -07:00
Peter Hawkins
d92823ecde Make type check for quantile stricter to match what is actually tested. 2019-07-08 13:35:57 -04:00
Peter Hawkins
dc16cb9514 Improve error messages. 2019-07-08 12:13:18 -04:00
Peter Hawkins
97e7455aea Implement np.quantile and np.percentile.
Only implements interpolation='linear' at the moment.
2019-07-08 12:08:22 -04:00
Sergei Lebedev
248ce6e388 Added support for creating arrays via the buffer interface
This allows to call `jax.numpy.array` on objects which do not expose
the `__array__` attribute, but can be viewed as an array through
`memoryview`.
2019-07-08 11:56:09 +01:00
Matthew Johnson
a46108c4b0
Merge pull request #983 from google/numpy-funs
add jax.numpy.cov and tests (cf. #70)
2019-07-06 19:20:23 -07:00
Matthew Johnson
ce2833367a add jax.numpy.cov and tests (cf. #70)
also add jax.numpy.array(..., ndmin=n)
2019-07-06 11:19:34 -07:00
Matthew Johnson
a5e86ae128 enable soft_pmap device persistence
Previously soft_pmap didn't allow for sharded device persistence because
it performs reshapes on the input and output of the underlying pmap
computation corrseponding to splitting out and merging together the
hardware-mapped and software-mapped axes, resepectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.

The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.

One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that when forced the former collects its shards
together using onp.stack while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).

ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collcetion, just like it does for
ShardedDeviceArrays.
2019-07-06 10:21:59 -07:00
Peter Hawkins
59be9b7a04 Minor doc fixes. 2019-07-02 15:00:47 -04:00
Peter Hawkins
81322caf18 Wrap np.cumsum/cumprod in a jit to avoid materializing padded output. 2019-07-02 11:48:43 -04:00
Peter Hawkins
c1b429be48 Add a jax.numpy.__init__ method that throws a TypeError if called.
Improves the error message for #956, where np.ndarray was called explicitly.
2019-07-01 14:55:39 -04:00
Peter Hawkins
f4ec87dec2 Add support for non-constant shifts to np.roll. 2019-06-27 11:25:27 -04:00
Peter Hawkins
014d235e3c Don't explicitly compute the length; we only need to know if the interval is empty. 2019-06-26 16:30:18 -04:00
Peter Hawkins
07723c4309 Use constant-time algorithm for static slice index calculation.
The current code was linear time in the time of the input array in some cases.

For the benchmark in https://github.com/google/jax/issues/927, compilation time improves from 18s to 0.2s on Mac. Interestingly the performance before this fix seems very different across platforms.
2019-06-26 16:08:48 -04:00
Matthew Johnson
fc367b10a1
Merge pull request #917 from google/soft-pmap
add soft_pmap, plus very rough draft of parallelize
2019-06-26 10:21:57 -07:00
Peter Hawkins
84fc8698eb Improve jax.numpy.arange to return a lazy iota even if an explicit dtype is provided. 2019-06-25 13:02:09 -04:00
Peter Hawkins
4abd1dbaa3 Form a tree of concatenations in jax.numpy.concatenate instead of a single wide concatenation.
Wide concatenations can be slow to compile, particularly on the CPU backend.

Benchmark:
%time np.array(list(range(10000)))
Wall time before: 24.6s
Wall time after: 0.86s.

(This still isn't great, but it's much better!)
2019-06-25 09:58:48 -04:00
Matthew Johnson
fe7329e808 iniital soft_pmap implementation 2019-06-24 19:34:48 -07:00
Peter Hawkins
8cab26f8de
Merge pull request #911 from hawkinsp/takealongaxis
Fix handling of broadcasting in jax.numpy.take_along_axis.
2019-06-24 11:27:26 -04:00
Peter Hawkins
8fc4ce2bdd Fix handling of broadcasting in jax.numpy.take_along_axis. 2019-06-24 10:34:48 -04:00
Peter Hawkins
2332b2ee6f Implement jax.numpy.select. 2019-06-24 09:27:01 -04:00