Matthew Johnson
061d033c2b
add jit around einsum
2018-12-18 23:20:10 -08:00
Matthew Johnson
6261ef729a
more einsum improvements (complete?)
2018-12-18 23:20:10 -08:00
Matthew Johnson
d8388e2d80
complete support for two-operand einsum
2018-12-18 23:20:10 -08:00
Matthew Johnson
fdde6841e6
add support for two-operrand cases
2018-12-18 23:20:10 -08:00
Matthew Johnson
13a0e1168e
fix broadcasted eye bug, enable more einsum
2018-12-18 23:20:10 -08:00
Matthew Johnson
6a71e9d6ec
start drafting an einsum implementation
2018-12-18 23:20:09 -08:00
Matthew Johnson
a18e3f27ac
add tests for device constants
2018-12-18 22:45:34 -08:00
Matthew Johnson
20ca0bd733
add cutoff for materialize-and-xfer vs build-on-device
...
see https://github.com/google/jax/pull/140#issuecomment-448433620
2018-12-18 17:19:47 -08:00
Matthew Johnson
78a4240581
add fix from einsum branch
2018-12-18 09:18:14 -08:00
Matthew Johnson
52c6eac3de
use lax.tie_in in jax.random for better consts
2018-12-18 09:16:59 -08:00
Matthew Johnson
bdc9e92f94
remove full_p to improve power:weight
2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2
add EyeConstant, new np.eye and np.array code
2018-12-18 09:16:59 -08:00
Matthew Johnson
dfc25a06d9
add IotaConstant (untested)
2018-12-18 09:16:59 -08:00
Matthew Johnson
f971415218
add tie_in and full primitives (constant creation)
2018-12-18 09:16:59 -08:00
Matthew Johnson
25cf9358d1
Merge pull request #131 from google/mean-kwargs
...
fix mean/var/std kwargs (closes #125 )
2018-12-17 15:34:05 -08:00
Matthew Johnson
5589a3cc58
fix up error messages
2018-12-17 15:13:41 -08:00
Peter Hawkins
f78bd8a550
Merge pull request #130 from hawkinsp/master
...
Progress towards calling LAPACK kernels on CPU for unbatched Cholesky and Triangular solve
2018-12-17 18:05:21 -05:00
Peter Hawkins
8c25c29295
Merge pull request #127 from j-towns/qr-jvp
...
Implement the JVP of the QR decomposition
2018-12-17 17:54:43 -05:00
Peter Hawkins
e0f421746b
Import CPU Lapack implementation conditionally to ease jaxlib upgrade.
2018-12-17 17:52:16 -05:00
Peter Hawkins
ab1ebc6bad
Add experimental warning to numpy.linalg and scipy.linalg.
2018-12-17 17:39:46 -05:00
Matthew Johnson
1f2925ea8a
add backend-specific translation table in xla.py
2018-12-17 17:30:27 -05:00
Matthew Johnson
7524f2c087
fix mean/var/std kwargs ( closes #125 )
2018-12-17 14:26:28 -08:00
Peter Hawkins
eac96ac239
Fix bugs with bool argument passing; pass PRED values as int32s instead.
2018-12-17 17:17:47 -05:00
Peter Hawkins
0333a98bab
Add triangular solve BLAS implementation.
2018-12-17 16:39:19 -05:00
Peter Hawkins
3c388b98f1
Add support for calling LAPACK primitives from SciPy from JAX linalg.
2018-12-17 16:30:27 -05:00
Roy Frostig
e0ad5bb394
add TODO comment for containers in api.make_jaxpr
2018-12-17 11:42:45 -08:00
Jamie Townsend
5cdf915fb6
Test qr jvp
2018-12-17 16:36:55 +00:00
Jamie Townsend
f5b8d97c95
Add url for qr jvp notes
2018-12-17 16:04:51 +00:00
Jamie Townsend
1743a936eb
Add qr decomposition jvp
2018-12-17 16:02:29 +00:00
Roy Frostig
f4a8e03ce1
add a basic make_jaxpr
transformation to the api module
2018-12-16 13:26:02 -08:00
Roy Frostig
3b8fdb050a
wrap "jit" around generated function name
2018-12-16 13:24:20 -08:00
Roy Frostig
b318f56928
generate name, module, and doctring for functions output from jit
.
2018-12-16 13:04:29 -08:00
Matthew Johnson
ea08ecd5f0
add promote_dtypes logic to tensordot
2018-12-16 11:33:57 -08:00
Matthew Johnson
2e20a60916
add tensordot
2018-12-15 21:59:18 -08:00
Matthew Johnson
bfe653c6b0
Tracer.__len__ should reflect on abstract value
...
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
06b7e54c02
Fix bug in "economic" mode in jax.scipy.linalg.qr where it returned the full decomposition.
2018-12-15 10:52:10 -05:00
Peter Hawkins
13a135d424
Implement lower=False case for scipy.linalg.cholesky.
...
Remove np.linalg.{dot,matmul,trace}, because these aren't part of the numpy API. I had previously misinterpreted the np.linalg documentation to mean that they also existed in that module.
2018-12-15 10:22:42 -05:00
Matthew Johnson
13b8e21a1c
squash conv grad bug introduced in 0d64aea
...
(loudly errored, didn't produce silently incorrect results!)
2018-12-14 18:40:50 -08:00
Matthew Johnson
6de5c8a698
add test-running instructions ( fixes #67 )
2018-12-14 16:48:08 -08:00
Matthew Johnson
c268929f2d
add 'dtype' arg to np.std, add test coverage
2018-12-14 16:22:51 -08:00
sschoenholz
5d6ebba2a0
Fixed argument order in call to var from std.
2018-12-14 11:58:03 -08:00
Matthew Johnson
b164d318fb
reduce_and / reduce_or monoid reducer primitives
...
The parent commit reused reduce_min / reduce_max on booleans, which is
formally equivalent but preserves less information when lowering to XLA.
2018-12-14 08:42:02 -08:00
Matthew Johnson
693365c239
np.all and np.any should lead to monoid reducers
...
fixes #108
2018-12-14 08:07:12 -08:00
Peter Hawkins
14acd1dbfa
Merge pull request #110 from hawkinsp/master
...
Implement np.linalg.inv using a QR decomposition.
2018-12-13 21:12:22 -05:00
Peter Hawkins
23525bd9b5
Add scipy.linalg.inv as well. Simplify the QR call in np.linalg.inv.
2018-12-13 21:02:24 -05:00
Matthew Johnson
9b645364c9
tests depend on scipy, tweaks
2018-12-13 16:40:03 -08:00
Peter Hawkins
3aad9b68f6
Implement np.linalg.inv using a QR decomposition.
...
An LU decomposition would probably be preferable; we can switch the implementation when we have an LU decomposition.
Fixes #44 .
2018-12-13 19:28:05 -05:00
Peter Hawkins
bc9e157459
Merge pull request #107 from hawkinsp/master
...
Make JAX flake8-clean.
2018-12-13 16:17:20 -05:00
Peter Hawkins
0d4eb6c1e1
Make JAX flake8-clean.
...
Fixes #1 .
2018-12-13 15:29:39 -05:00
Matthew Johnson
c5d6c9f09f
Merge remote-tracking branch 'origin/master'
2018-12-13 11:56:23 -08:00