16407 Commits

Author SHA1 Message Date
Matthew Johnson
c4d32ca2cd support x[[0, 2, 4], [0, 2, 4]] indexing, fix #187 2019-01-02 17:46:46 -08:00
Matthew Johnson
a627cc80e8 make random.split return something vmap-compatible
(in particular, return an array rather than a tuple, c.f. #181)
2019-01-02 12:52:39 -08:00
Matthew Johnson
072e6f78f1 replace PRNGKey class with uint32[2] array 2018-12-30 21:42:55 -08:00
Matthew Johnson
c5c6e6c5c7
Merge pull request #178 from google/dkd
add np.append and np.polyval
2018-12-30 20:53:11 -08:00
Matthew Johnson
5d2fc7c05b only test np.polyval on nonscalar array shapes 2018-12-30 18:07:50 -08:00
Matthew Johnson
9c49a9bfe6 add np.append and np.polyval 2018-12-30 17:49:11 -08:00
Matthew Johnson
61d5d79e39
Merge pull request #175 from google/stax-fan-in-concat
Add stax.FanInConcat
2018-12-30 17:12:12 -08:00
Matthew Johnson
b4246163ac add stax.FanInConcat (fixes #174) 2018-12-30 16:51:32 -08:00
Roy Frostig
b2d6ce175a add a placeholder comment for lbr tests 2018-12-28 13:51:32 -08:00
Matthew Johnson
59142a2494 add error message if Dropout gets no rng key
closes #170 (though there's more work to be done here)
2018-12-24 10:33:13 -08:00
Matthew Johnson
6d6b5263fe add non-advanced boolean indexing support
also don't sub-sample indexing tests (run them all)
fixes #166
2018-12-23 11:02:20 -08:00
Matthew Johnson
d48e7ef43d add batching (vmap) rule for lax.dynamic_slice
fixes #165
2018-12-23 09:28:23 -08:00
Peter Hawkins
b1ff7b4ff9 Fix missing raise in error path. 2018-12-22 13:59:06 -05:00
Peter Hawkins
06135fa6f5 Implement numpy.linalg.solve and scipy.linalg.solve.
Make Cholesky and TriangularSolve work for complex numbers on CPU. The HLO implementations are broken for complex numbers on GPU/TPU, so no tests enabled for these yet.
2018-12-21 16:29:45 -05:00
Peter Hawkins
54772e1b9b
Merge pull request #163 from hawkinsp/master
Use get() rather than a try-catch block in memoized function lookup.
2018-12-21 15:21:02 -05:00
Peter Hawkins
a4386457e2 Fix test failures due to type mismatches in linear algebra tests.
Minor code cleanups.
2018-12-21 15:18:34 -05:00
Peter Hawkins
8127392a18 Use get() rather than a try-catch block in memoized function lookup.
Currently backtraces often look like this:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
    133     try:
--> 134       return cache[key]
    135     except KeyError:

KeyError: ((lu, ShapedArray(int32[2,2])), ())

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
    133     try:
--> 134       return cache[key]
    135     except KeyError:

KeyError: ((lu, xla_client.Shape(_dtype=dtype('int32'), _dimensions=(2, 2), _is_tuple=False, _minor_to_major=None)), ())

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-26-d6c00d50e3c9> in <module>
```

The "during handling of the above exception..." message is mostly a distraction for the user that occurs because we perform the memoized function evaluation inside a `catch` block. By performing the function evaluation outside the catch block, we can get better backtraces without the distraction of the KeyError exception.

```
2018-12-21 13:32:56 -05:00
Peter Hawkins
e05ac8d5de
Merge pull request #158 from hawkinsp/master
Implement np.kron.
2018-12-21 13:11:48 -05:00
Matthew Johnson
58364f0333
Merge pull request #161 from google/add-jaxvals-batching
cover unimplemented add_jaxvals_p batching case
2018-12-21 09:29:27 -08:00
Matthew Johnson
43f2e2a70a cover unimplemented add_jaxvals_p batching case 2018-12-21 08:11:36 -08:00
Peter Hawkins
0a75007c19 Implement np.kron. 2018-12-21 10:28:45 -05:00
Peter Hawkins
bd50c5b6b5 Implement np.{isposinf,isneginf,nansum,nanprod,nanmin,nanmax,nan_to_num}.
No tests until we figure about a story about fast-math semantics.
2018-12-21 08:52:01 -05:00
Peter Hawkins
b68c93d37f Implement np.linalg.slogdet.
Change implementation of np.linalg.logdet to call np.linalg.slogdet.

Add support for complex64 LU decomposition.
2018-12-20 22:18:20 -05:00
Peter Hawkins
1815f17a5b Be more defensive about old jaxlib versions and non-CPU devices in LU decomposition usage. 2018-12-20 21:04:02 -05:00
Peter Hawkins
dfdc2e3806 Add LU decomposition implementation backed by LAPACK on the CPU platform.
Implement np.linalg.det, and scipy.linalg.{lu,lu_factor,det}.

Add missing abstractification to loop arguments.
Implement XLA abstractification rules for AbstractTuple, ConcreteArray, and ShapedArray.
2018-12-20 18:45:34 -05:00
Peter Hawkins
5cf642e326
Merge pull request #153 from hawkinsp/numpy
Implement np.{identity,count_nonzero}.
2018-12-20 16:12:41 -05:00
Peter Hawkins
4fc73205b2 Implement np.{diagonal,count_nonzero}.
Fix shape error if a scalar is passed to reducers. Add test for scalar reductions.
2018-12-20 15:47:49 -05:00
Matthew Johnson
4124932b89
Merge pull request #151 from google/value-and-grad
Add `value_and_grad` (closes #149)
2018-12-20 10:31:28 -08:00
Matthew Johnson
8f1bc997ca add value-and-grad fun (closes #149) 2018-12-20 10:09:34 -08:00
Peter Hawkins
bf73c1e282 Use lax.bitwise_ in isnan(). 2018-12-20 10:49:52 -05:00
Peter Hawkins
7f6119c7cd Implement np.{full_like,isinf,isnan}. Fix np.isfinite.
Note that inf/nan behavior may not be correct in fastmath mode.
2018-12-20 10:36:32 -05:00
Peter Hawkins
95135377d0 Add implementations of np.{meshgrid,linspace,logspace,geomspace,diag_indices} that forward to the usual numpy implementation. 2018-12-20 08:28:36 -05:00
Matthew Johnson
449da4cdb5
Merge pull request #147 from google/einsum
fix einsum bugs, add test cases
2018-12-19 17:12:16 -08:00
Matthew Johnson
6bb9609fb8 disable test, py3 opt_einsum nondeterministic bug? 2018-12-19 16:58:31 -08:00
Matthew Johnson
6a138202ef fix several einsum bugs 2018-12-19 16:15:43 -08:00
Matthew Johnson
569db1698c
Merge pull request #148 from google/solve-triangular-jvp
Solve triangular jvp rule
2018-12-19 15:34:11 -08:00
Dougal Maclaurin
e6b23dd2b7 Fixed triangular_solve_jvp_rule for transpose_a=True case 2018-12-19 17:47:56 -05:00
Matthew Johnson
9c722db373 einsum: update id strings after moving batch dims 2018-12-19 10:59:03 -08:00
Matthew Johnson
9a68bce567 add comment marking a bug 2018-12-19 10:42:40 -08:00
Matthew Johnson
43e77acca5 fix select transpose rule 2018-12-19 09:40:40 -08:00
Matthew Johnson
dc1d0c260a always fall back to onp.arange for now 2018-12-19 09:21:30 -08:00
Matthew Johnson
6a9952a939 jax.numpy.arange should fall back to onp.arange
fixes #145
2018-12-19 09:07:04 -08:00
Matthew Johnson
c56f43f2a5
Merge pull request #144 from google/einsum
fix einsum tensor product logic (fixes #37)
2018-12-19 08:27:16 -08:00
Dougal Maclaurin
87922fdf13 Generalized make_jaxpr to handle python containers 2018-12-19 10:59:13 -05:00
Matthew Johnson
997c9c5a50 fix einsum tensor product logic (fixes #37)
The error was that `lhs_names` and `rhs_names` included `batch_names` as
prefixes, but the reshaping logic was written as if they did not include
batch_names (and so batch_names had to be prepended).
2018-12-19 07:59:00 -08:00
Peter Hawkins
a154b9c691
Merge pull request #143 from hawkinsp/astype
Fix implementation of ndarray.astype method, add a test.
2018-12-19 09:57:26 -05:00
Peter Hawkins
d3bb93f82a Fix implementation of ndarray.astype method, add a test.
Previously we were creating an operator named __astype__, which isn't a thing in numpy.
2018-12-19 09:28:08 -05:00
Peter Hawkins
dbab47fdb0 Implement np.inner and np.outer. 2018-12-19 08:58:07 -05:00
Matthew Johnson
0a9ee106b3 implement triangular solve lhs jvp (w/ @froystig) 2018-12-18 23:39:46 -08:00
Matthew Johnson
d7745dd9af actually fix py3 str translate 2018-12-18 23:20:10 -08:00