197 Commits

Author SHA1 Message Date
Peter Hawkins
97e7455aea Implement np.quantile and np.percentile.
Only implements interpolation='linear' at the moment.
2019-07-08 12:08:22 -04:00
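The 'linear' interpolation rule mentioned above can be sketched in plain NumPy (a hypothetical reference implementation for one scalar q, not the actual JAX code):

```python
import numpy as np

def linear_quantile(a, q):
    # 'linear' interpolation: target index is q * (n - 1); interpolate
    # between the two nearest sorted values.
    a = np.sort(np.ravel(a))
    idx = q * (a.size - 1)
    lo = int(np.floor(idx))
    hi = int(np.ceil(idx))
    frac = idx - lo
    return a[lo] * (1 - frac) + a[hi] * frac

vals = np.array([1.0, 2.0, 3.0, 10.0])
print(linear_quantile(vals, 0.5))  # 2.5, matching np.quantile(vals, 0.5)
```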
Matthew Johnson
a46108c4b0
Merge pull request #983 from google/numpy-funs
add jax.numpy.cov and tests (cf. #70)
2019-07-06 19:20:23 -07:00
Matthew Johnson
ce2833367a add jax.numpy.cov and tests (cf. #70)
also add jax.numpy.array(..., ndmin=n)
2019-07-06 11:19:34 -07:00
Matthew Johnson
a5e86ae128 enable soft_pmap device persistence
Previously soft_pmap didn't allow for sharded device persistence because
it performed reshapes on the input and output of the underlying pmap
computation, corresponding to splitting out and merging together the
hardware-mapped and software-mapped axes, respectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.

The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.

One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that when forced the former collects its shards
together using onp.stack while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).

ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collection, just like it does for
ShardedDeviceArrays.
2019-07-06 10:21:59 -07:00
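A small NumPy sketch of the stack-vs-concatenate distinction described in the commit message above (the shapes and the two-device layout are hypothetical, chosen only for illustration):

```python
import numpy as np

# Hypothetical per-device shards produced by a pmap over 2 devices.
shards = [np.arange(6).reshape(2, 3), np.arange(6, 12).reshape(2, 3)]

# ShardedDeviceArray ("s" for stack): logical shape (2, 2, 3), with
# the hardware-mapped axis split out as the leading dimension.
stacked = np.stack(shards)          # shape (2, 2, 3)

# ChunkedDeviceArray ("c" for concatenate): the axis-merging reshape
# (2, 2, 3) -> (4, 3) gives the same values as concatenating the
# shards, which is why no data movement is needed.
chunked = np.concatenate(shards)    # shape (4, 3)

assert np.array_equal(stacked.reshape(4, 3), chunked)
```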
Peter Hawkins
59be9b7a04 Minor doc fixes. 2019-07-02 15:00:47 -04:00
Peter Hawkins
81322caf18 Wrap np.cumsum/cumprod in a jit to avoid materializing padded output. 2019-07-02 11:48:43 -04:00
Peter Hawkins
c1b429be48 Add a jax.numpy.__init__ method that throws a TypeError if called.
Improves the error message for #956, where np.ndarray was called explicitly.
2019-07-01 14:55:39 -04:00
Peter Hawkins
f4ec87dec2 Add support for non-constant shifts to np.roll. 2019-06-27 11:25:27 -04:00
Peter Hawkins
014d235e3c Don't explicitly compute the length; we only need to know if the interval is empty. 2019-06-26 16:30:18 -04:00
Peter Hawkins
07723c4309 Use constant-time algorithm for static slice index calculation.
The current code was linear time in the size of the input array in some cases.

For the benchmark in https://github.com/google/jax/issues/927, compilation time improves from 18s to 0.2s on Mac. Interestingly the performance before this fix seems very different across platforms.
2019-06-26 16:08:48 -04:00
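One constant-time way to resolve a static slice into concrete bounds is Python's built-in `slice.indices` (a sketch of the idea only; the commit does not say this is how JAX implements it):

```python
# slice.indices(length) normalizes start/stop/step in O(1), handling
# negative indices and steps, with no O(n) scan over the array.
s = slice(None, None, -2)
length = 10
start, stop, step = s.indices(length)
indices = range(start, stop, step)  # lazy; nothing materialized
print(list(indices))  # [9, 7, 5, 3, 1]
```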
Matthew Johnson
fc367b10a1
Merge pull request #917 from google/soft-pmap
add soft_pmap, plus very rough draft of parallelize
2019-06-26 10:21:57 -07:00
Peter Hawkins
84fc8698eb Improve jax.numpy.arange to return a lazy iota even if an explicit dtype is provided. 2019-06-25 13:02:09 -04:00
Peter Hawkins
4abd1dbaa3 Form a tree of concatenations in jax.numpy.concatenate instead of a single wide concatenation.
Wide concatenations can be slow to compile, particularly on the CPU backend.

Benchmark:
%time np.array(list(range(10000)))
Wall time before: 24.6s
Wall time after: 0.86s.

(This still isn't great, but it's much better!)
2019-06-25 09:58:48 -04:00
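The tree-of-concatenations idea can be sketched as follows (`leaf_size` is an arbitrary illustrative threshold, not JAX's actual choice):

```python
import numpy as np

def tree_concatenate(arrays, leaf_size=16):
    # Build a balanced tree of narrow concatenations rather than one
    # wide concatenation, which can be slow to compile on some backends.
    if len(arrays) <= leaf_size:
        return np.concatenate(arrays)
    mid = len(arrays) // 2
    return np.concatenate([tree_concatenate(arrays[:mid], leaf_size),
                           tree_concatenate(arrays[mid:], leaf_size)])

parts = [np.array([i]) for i in range(100)]
assert np.array_equal(tree_concatenate(parts), np.concatenate(parts))
```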
Matthew Johnson
fe7329e808 initial soft_pmap implementation 2019-06-24 19:34:48 -07:00
Peter Hawkins
8cab26f8de
Merge pull request #911 from hawkinsp/takealongaxis
Fix handling of broadcasting in jax.numpy.take_along_axis.
2019-06-24 11:27:26 -04:00
Peter Hawkins
8fc4ce2bdd Fix handling of broadcasting in jax.numpy.take_along_axis. 2019-06-24 10:34:48 -04:00
Peter Hawkins
2332b2ee6f Implement jax.numpy.select. 2019-06-24 09:27:01 -04:00
Peter Hawkins
3615635d19 Set copy=False in asarray. 2019-06-21 14:02:11 -04:00
Peter Hawkins
cc0cdc30d7 Force a copy to device in jax.numpy.array() if copy=True.
The motivation is the following example where the array is mutated after being passed to jax.numpy.array:
```
>>> a = np.array([42])
>>> b = jnp.array(a)
>>> a[0] = 24
>>> b
array([24])
```

Also fix up asarray() to have the precise signature of onp.asarray.
2019-06-21 12:12:22 -04:00
Peter Hawkins
293b983981 Enable direct device-to-device copies on GPU and TPU.
Update XLA to include device-to-device copies.
2019-06-21 10:27:34 -04:00
Peter Hawkins
c2ac84e7d4 Add support for reflect, symmetric, and wrap padding modes to np.pad. 2019-06-20 19:50:12 -04:00
Peter Hawkins
f9c72effed Use _canonicalize_axis for reduction axis dimensions to catch invalid axes. 2019-06-20 08:40:21 -04:00
Peter Hawkins
24ea3ac32e Add a jit annotation around np.where.
Avoids materializing broadcast scalars inside where in op-by-op mode.

Since np.tril appears in the linear part of the Cholesky JVP rule, change np.tril/triu to avoid where in favor of calling lax.select() directly. Ban 1D arguments to np.tril/triu, which aren't documented behavior of the numpy implementation.
2019-06-18 19:08:16 -04:00
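The mask-and-select approach to tril can be sketched in NumPy (using np.where as a stand-in for the lax.select primitive; this is an illustration, not the JAX source):

```python
import numpy as np

def tril_select(x):
    # Build the lower-triangular mask explicitly and select against
    # zeros, rather than routing through a general-purpose where.
    m, n = x.shape
    mask = np.arange(m)[:, None] >= np.arange(n)[None, :]
    return np.where(mask, x, np.zeros_like(x))

x = np.arange(9).reshape(3, 3)
assert np.array_equal(tril_select(x), np.tril(x))
```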
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
fbdb204d68 Fix type of np.ix_ for empty arrays to match numpy. 2019-06-17 20:03:10 -04:00
Peter Hawkins
ec685bf8ae Implement np.ix_, for non-bool inputs. 2019-06-17 17:08:27 -04:00
Matthew Johnson
a56a7d02ff make threefry_2x32 not do any op-by-op stuff 2019-06-11 14:56:21 -07:00
Matthew Johnson
aebf7eb088 fix jax.numpy.transpose arg name 'axes' 2019-06-10 12:17:20 -07:00
Matthew Johnson
1829508b28 np.arange shouldn't pop its kwargs (fixes #830) 2019-06-09 20:18:18 -07:00
Peter Hawkins
6e1ec38a14 Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
2019-05-29 12:51:24 -04:00
Matthew Johnson
8eb10835ff bring save, savez, load into jax.numpy namespace
fixes #712
2019-05-28 20:52:52 -07:00
Matthew Johnson
14d16b2d66 support both .reshape(*shape) and .reshape(shape)
fixes #746
2019-05-21 21:37:52 -07:00
Matthew Johnson
bf64d9642d
Merge pull request #742 from google/issue740
handle tensordot with zero contracting dims
2019-05-20 17:53:11 -07:00
Matthew Johnson
dcb8584b07 handle tensordot with zero contracting dims
fixes #740
2019-05-20 17:21:53 -07:00
Matthew Johnson
d0e1b7be35 wrap np.trace axes (fixes #738) 2019-05-20 17:11:18 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Aditya Vaidya
4ef154a3af Fixes, and skip test if ZeroDivisionError 2019-05-08 02:29:37 -05:00
Aditya Vaidya
3468e87f02 Average fixes. Doesn't work for non-empty shapes 2019-05-08 02:19:20 -05:00
Aditya Vaidya
725a469538 Started work on np.average 2019-05-08 02:19:20 -05:00
Peter Hawkins
fd53e394ff Fix test case for nan_to_num when enable_x64 is False. 2019-05-07 16:29:45 -04:00
Peter Hawkins
6d77fb7d20 Fix type mismatch for nan_to_num for 64-bit types. Fixes #683.
Add tests for isinf/isnan/isposinf/isneginf/nan_to_num now that nan/inf are honored on the CPU backend.

Add complex number support to more of the RNG test utils. Add test RNG that emits both nans and infs.
2019-05-07 15:07:43 -04:00
Matthew Johnson
9adfb80625 add advanced indexing support to jax index ops
fixes #658

This commit adds advanced indexing support to jax index operations,
namely index_update and index_add, but does *not* add support for mixed
advanced indexing and slicing. That's left as a NotImplementedError.

This commit also added a segment_sum convenience wrapper.
2019-05-06 14:20:24 -07:00
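The segment_sum wrapper's semantics can be sketched in NumPy (a hypothetical reference implementation; in JAX it is built on the index ops so it can be jit-compiled and differentiated):

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    # Sum rows of `data` into buckets named by `segment_ids`.
    out = np.zeros((num_segments,) + data.shape[1:], data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered scatter-add
    return out

data = np.array([1.0, 2.0, 3.0, 4.0])
ids = np.array([0, 0, 1, 1])
print(segment_sum(data, ids, 2))  # [3. 7.]
```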
Matthew Johnson
ddd29e724e fix DeviceArray.__repr__ for complex dtypes, test
c.f. #666
2019-05-02 19:27:22 -07:00
Matthew Johnson
a0fcb3fb9d
Merge pull request #660 from Bharat123rox/npfix
Implemented np.fix
2019-04-30 16:56:03 -07:00
Sharad Vikram
d92fb06939 add lax numpy.tile implementation 2019-04-30 13:13:56 -07:00
Bharat123rox
69d12111fc Implemented np.fix 2019-04-30 22:33:25 +05:30
Peter Hawkins
a18d1971ff
Merge pull request #640 from kroq-gar78/atleast_nd-scalars
Fix 'atleast_{1,2,3}d' for scalars
2019-04-30 07:08:46 -07:00
Peter Hawkins
c47cca2058 Perform division in mean using the target dtype, rather than performing a true_divide and then casting back to the correct type. 2019-04-26 15:51:45 -07:00
Aditya Vaidya
40e3056e65 Fix 'atleast_<n>d' for scalars 2019-04-25 01:01:02 -05:00
Dheeraj Rajaram Reddy
8bcd0afac5 Use lax.sign instead of lax.div(x, lax.abs(x)) 2019-04-14 22:42:52 +05:30