fehiepsi
890ba842a9
implement expon and laplace sampler
2019-03-28 23:57:00 -04:00
fehiepsi
a0636eaedd
implement cauchy pdf and random sampler
2019-03-28 17:59:42 -04:00
fehiepsi
f05d0bcbd8
fix bernoulli shape bug
2019-03-02 19:11:16 -05:00
Masahiro H
d3acbc1c73
Fix typo in PRNGKey docstring
2019-02-16 23:31:27 +09:00
Matthew Johnson
910848afe1
fix typo in random.py docstring
2019-02-13 20:05:15 -08:00
Matthew Johnson
11122bc8e3
improve jax.random docs
2019-02-13 19:42:47 -08:00
Matthew Johnson
0ff98a74eb
add random.fold_in, update mnist_vae.py loops
2019-02-13 09:55:36 -08:00
Matthew Johnson
70d1a00443
set behavior when random.randint has invalid range
...
(closes #222 )
2019-01-12 13:12:40 -08:00
Matthew Johnson
a627cc80e8
make random.split return something vmap-compatible
...
(in particular, return an array rather than a tuple, c.f. #181 )
2019-01-02 12:52:39 -08:00
Matthew Johnson
072e6f78f1
replace PRNGKey class with uint32[2] array
2018-12-30 21:42:55 -08:00
Matthew Johnson
52c6eac3de
use lax.tie_in in jax.random for better consts
2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2
add EyeConstant, new np.eye and np.array code
2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0
Tracer.__len__ should reflect on abstract value
...
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Matthew Johnson
a285017110
fix failing tests (misc small bugs)
2018-12-13 11:52:41 -08:00
Alex Wiltschko
0b7bed8778
Adding unimplemented functions to numpy.random, numpy.fft and numpy.linalg
2018-12-11 12:44:02 -05:00
Matthew Johnson
bbc92ce6eb
Split out jax
and jaxlib
packages ( #11 )
...
factor out 'jaxlib' as separate package
2018-12-06 21:35:03 -05:00
Peter Hawkins
c1b9eb19ea
[JAX] Change semantics of dtype promotion to just call numpy.result_type.
...
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Matthew Johnson
c293f7c875
minor: add @jit to threefry hash function in random.py
...
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Matthew Johnson
8317cc3618
source sync
...
PiperOrigin-RevId: 222484671
2018-11-21 20:32:33 -08:00
Matthew Johnson
2ae9a2bc35
source sync
...
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
e180f08113
source sync
...
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Peter Hawkins
065bb0baa2
source sync
...
PiperOrigin-RevId: 222291726
2018-11-21 20:22:41 -08:00
Matthew Johnson
46c6a9170f
sync updates
2018-11-19 07:47:59 -08:00
Matthew Johnson
a30e858e59
populating source tree
2018-11-17 18:03:33 -08:00