16407 Commits

Author SHA1 Message Date
Matthew Johnson
061d033c2b add jit around einsum 2018-12-18 23:20:10 -08:00
Matthew Johnson
6261ef729a more einsum improvements (complete?) 2018-12-18 23:20:10 -08:00
Matthew Johnson
d8388e2d80 complete support for two-operand einsum 2018-12-18 23:20:10 -08:00
Matthew Johnson
fdde6841e6 add support for two-operrand cases 2018-12-18 23:20:10 -08:00
Matthew Johnson
13a0e1168e fix broadcasted eye bug, enable more einsum 2018-12-18 23:20:10 -08:00
Matthew Johnson
6a71e9d6ec start drafting an einsum implementation 2018-12-18 23:20:09 -08:00
Matthew Johnson
a18e3f27ac add tests for device constants 2018-12-18 22:45:34 -08:00
Matthew Johnson
20ca0bd733 add cutoff for materialize-and-xfer vs build-on-device
see https://github.com/google/jax/pull/140#issuecomment-448433620
2018-12-18 17:19:47 -08:00
Matthew Johnson
78a4240581 add fix from einsum branch 2018-12-18 09:18:14 -08:00
Matthew Johnson
52c6eac3de use lax.tie_in in jax.random for better consts 2018-12-18 09:16:59 -08:00
Matthew Johnson
bdc9e92f94 remove full_p to improve power:weight 2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2 add EyeConstant, new np.eye and np.array code 2018-12-18 09:16:59 -08:00
Matthew Johnson
dfc25a06d9 add IotaConstant (untested) 2018-12-18 09:16:59 -08:00
Matthew Johnson
f971415218 add tie_in and full primitives (constant creation) 2018-12-18 09:16:59 -08:00
Matthew Johnson
25cf9358d1
Merge pull request #131 from google/mean-kwargs
fix mean/var/std kwargs (closes #125)
2018-12-17 15:34:05 -08:00
Matthew Johnson
5589a3cc58 fix up error messages 2018-12-17 15:13:41 -08:00
Peter Hawkins
f78bd8a550
Merge pull request #130 from hawkinsp/master
Progress towards calling LAPACK kernels on CPU for unbatched Cholesky and Triangular solve
2018-12-17 18:05:21 -05:00
Peter Hawkins
8c25c29295
Merge pull request #127 from j-towns/qr-jvp
Implement the JVP of the QR decomposition
2018-12-17 17:54:43 -05:00
Peter Hawkins
e0f421746b Import CPU Lapack implementation conditionally to ease jaxlib upgrade. 2018-12-17 17:52:16 -05:00
Peter Hawkins
ab1ebc6bad Add experimental warning to numpy.linalg and scipy.linalg. 2018-12-17 17:39:46 -05:00
Matthew Johnson
1f2925ea8a add backend-specific translation table in xla.py 2018-12-17 17:30:27 -05:00
Matthew Johnson
7524f2c087 fix mean/var/std kwargs (closes #125) 2018-12-17 14:26:28 -08:00
Peter Hawkins
eac96ac239 Fix bugs with bool argument passing; pass PRED values as int32s instead. 2018-12-17 17:17:47 -05:00
Peter Hawkins
0333a98bab Add triangular solve BLAS implementation. 2018-12-17 16:39:19 -05:00
Peter Hawkins
3c388b98f1 Add support for calling LAPACK primitives from SciPy from JAX linalg. 2018-12-17 16:30:27 -05:00
Roy Frostig
e0ad5bb394 add TODO comment for containers in api.make_jaxpr 2018-12-17 11:42:45 -08:00
Jamie Townsend
5cdf915fb6 Test qr jvp 2018-12-17 16:36:55 +00:00
Jamie Townsend
f5b8d97c95 Add url for qr jvp notes 2018-12-17 16:04:51 +00:00
Jamie Townsend
1743a936eb Add qr decomposition jvp 2018-12-17 16:02:29 +00:00
Roy Frostig
f4a8e03ce1 add a basic make_jaxpr transformation to the api module 2018-12-16 13:26:02 -08:00
Roy Frostig
3b8fdb050a wrap "jit" around generated function name 2018-12-16 13:24:20 -08:00
Roy Frostig
b318f56928 generate name, module, and doctring for functions output from jit. 2018-12-16 13:04:29 -08:00
Matthew Johnson
ea08ecd5f0 add promote_dtypes logic to tensordot 2018-12-16 11:33:57 -08:00
Matthew Johnson
2e20a60916 add tensordot 2018-12-15 21:59:18 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
06b7e54c02 Fix bug in "economic" mode in jax.scipy.linalg.qr where it returned the full decomposition. 2018-12-15 10:52:10 -05:00
Peter Hawkins
13a135d424 Implement lower=False case for scipy.linalg.cholesky.
Remove np.linalg.{dot,matmul,trace}, because these aren't part of the numpy API. I had previously misinterpreted the np.linalg documentation to mean that they also existed in that module.
2018-12-15 10:22:42 -05:00
Matthew Johnson
13b8e21a1c squash conv grad bug introduced in 0d64aea
(loudly errored, didn't produce silently incorrect results!)
2018-12-14 18:40:50 -08:00
Matthew Johnson
6de5c8a698 add test-running instructions (fixes #67) 2018-12-14 16:48:08 -08:00
Matthew Johnson
c268929f2d add 'dtype' arg to np.std, add test coverage 2018-12-14 16:22:51 -08:00
sschoenholz
5d6ebba2a0
Fixed argument order in call to var from std. 2018-12-14 11:58:03 -08:00
Matthew Johnson
b164d318fb reduce_and / reduce_or monoid reducer primitives
The parent commit reused reduce_min / reduce_max on booleans, which is
formally equivalent but preserves less information when lowering to XLA.
2018-12-14 08:42:02 -08:00
Matthew Johnson
693365c239 np.all and np.any should lead to monoid reducers
fixes #108
2018-12-14 08:07:12 -08:00
Peter Hawkins
14acd1dbfa
Merge pull request #110 from hawkinsp/master
Implement np.linalg.inv using a QR decomposition.
2018-12-13 21:12:22 -05:00
Peter Hawkins
23525bd9b5 Add scipy.linalg.inv as well. Simplify the QR call in np.linalg.inv. 2018-12-13 21:02:24 -05:00
Matthew Johnson
9b645364c9 tests depend on scipy, tweaks 2018-12-13 16:40:03 -08:00
Peter Hawkins
3aad9b68f6 Implement np.linalg.inv using a QR decomposition.
An LU decomposition would probably be preferable; we can switch the implementation when we have an LU decomposition.

Fixes #44.
2018-12-13 19:28:05 -05:00
Peter Hawkins
bc9e157459
Merge pull request #107 from hawkinsp/master
Make JAX flake8-clean.
2018-12-13 16:17:20 -05:00
Peter Hawkins
0d4eb6c1e1 Make JAX flake8-clean.
Fixes #1.
2018-12-13 15:29:39 -05:00
Matthew Johnson
c5d6c9f09f Merge remote-tracking branch 'origin/master' 2018-12-13 11:56:23 -08:00