24 Commits

Author SHA1 Message Date
fehiepsi
890ba842a9 implement expon and laplace sampler 2019-03-28 23:57:00 -04:00
fehiepsi
a0636eaedd implement cauchy pdf and random sampler 2019-03-28 17:59:42 -04:00
fehiepsi
f05d0bcbd8 fix bernoulli shape bug 2019-03-02 19:11:16 -05:00
Masahiro H
d3acbc1c73
Fix typo in PRNGKey docstring 2019-02-16 23:31:27 +09:00
Matthew Johnson
910848afe1 fix typo in random.py docstring 2019-02-13 20:05:15 -08:00
Matthew Johnson
11122bc8e3 improve jax.random docs 2019-02-13 19:42:47 -08:00
Matthew Johnson
0ff98a74eb add random.fold_in, update mnist_vae.py loops 2019-02-13 09:55:36 -08:00
Matthew Johnson
70d1a00443 set behavior when random.randint has invalid range
(closes #222)
2019-01-12 13:12:40 -08:00
Matthew Johnson
a627cc80e8 make random.split return something vmap-compatible
(in particular, return an array rather than a tuple, c.f. #181)
2019-01-02 12:52:39 -08:00
Matthew Johnson
072e6f78f1 replace PRNGKey class with uint32[2] array 2018-12-30 21:42:55 -08:00
Matthew Johnson
52c6eac3de use lax.tie_in in jax.random for better consts 2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2 add EyeConstant, new np.eye and np.array code 2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Matthew Johnson
a285017110 fix failing tests (misc small bugs) 2018-12-13 11:52:41 -08:00
Alex Wiltschko
0b7bed8778 Adding unimplemented functions to numpy.random, numpy.fft and numpy.linalg 2018-12-11 12:44:02 -05:00
Matthew Johnson
bbc92ce6eb
Split out jax and jaxlib packages (#11)
factor out 'jaxlib' as separate package
2018-12-06 21:35:03 -05:00
Peter Hawkins
c1b9eb19ea [JAX] Change semantics of dtype promotion to just call numpy.result_type.
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Matthew Johnson
c293f7c875 minor: add @jit to threefry hash function in random.py
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Matthew Johnson
8317cc3618 source sync
PiperOrigin-RevId: 222484671
2018-11-21 20:32:33 -08:00
Matthew Johnson
2ae9a2bc35 source sync
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Peter Hawkins
065bb0baa2 source sync
PiperOrigin-RevId: 222291726
2018-11-21 20:22:41 -08:00
Matthew Johnson
46c6a9170f sync updates 2018-11-19 07:47:59 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00