48 Commits

Author SHA1 Message Date
fehiepsi
907925fb30 improve compling time of gamma 2019-06-27 13:07:26 -04:00
fehiepsi
9252d596eb implement gamma grad 2019-06-27 13:07:25 -04:00
Matthew Johnson
a56a7d02ff make threefry_2x32 not do any op-by-op stuff 2019-06-11 14:56:21 -07:00
Matthew Johnson
c743c14c8f address reviewer comments 2019-05-22 20:06:12 -07:00
Matthew Johnson
93e201f85b make jax.random default dtypes 64-bit
fixes #756
2019-05-22 16:22:12 -07:00
Matthew Johnson
b68ed2787a fix randint docstring 2019-05-16 10:36:30 -07:00
Matthew Johnson
ea8e414a83 improve jax.random shape error messages 2019-05-09 11:40:19 -07:00
Matthew Johnson
910eac8c01
Merge branch 'master' into dirichlet 2019-04-23 17:07:51 -07:00
Matthew Johnson
376ee6423f
Merge pull request #631 from fehiepsi/bernoulli
Implement bernoulli logpmf
2019-04-23 16:59:40 -07:00
fehiepsi
a790436be4 add dirichlet sampler 2019-04-22 11:55:02 -04:00
fehiepsi
665e72a23a implement bernoulli pmf 2019-04-21 21:22:50 -04:00
fehiepsi
088081711d add tests 2019-04-21 16:43:18 -04:00
fehiepsi
e5ccf0534d add beta/t random samplers 2019-04-21 16:25:20 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Skye Wanderman-Milne
105e46f379 Factor out control flow from lax.py into lax_control_flow.py.
Also moves control flow tests to lax_control_flow_test.py.
2019-04-12 13:57:18 -07:00
Matthew Johnson
35e340670c
Merge branch 'master' into gamma 2019-04-12 07:15:41 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Jamie Townsend
996c62337c Rm unnecessary dtype arg name 2019-04-02 10:56:44 +01:00
Jamie Townsend
7b13ae4b41 Add Gumbel to jax.random, and test 2019-04-02 10:55:03 +01:00
fehiepsi
69c4e22524 add dtype for samplers in gamma_one 2019-04-01 00:32:42 -04:00
fehiepsi
80e22d2fdf Merge remote-tracking branch 'upstream/master' into gamma 2019-03-31 23:54:46 -04:00
fehiepsi
cbf45282ee convert gamma_one to lax api and move the inner while_loop to the outside 2019-03-31 23:54:31 -04:00
fehiepsi
e0567b6d16 add Gamma sampler 2019-03-30 18:07:34 -04:00
fehiepsi
bb095d3df5 implement pareto logpdf and sampler 2019-03-30 16:34:20 -04:00
fehiepsi
890ba842a9 implement expon and laplace sampler 2019-03-28 23:57:00 -04:00
fehiepsi
a0636eaedd implement cauchy pdf and random sampler 2019-03-28 17:59:42 -04:00
fehiepsi
f05d0bcbd8 fix bernoulli shape bug 2019-03-02 19:11:16 -05:00
Masahiro H
d3acbc1c73
Fix typo in PRNGKey docstring 2019-02-16 23:31:27 +09:00
Matthew Johnson
910848afe1 fix typo in random.py docstring 2019-02-13 20:05:15 -08:00
Matthew Johnson
11122bc8e3 improve jax.random docs 2019-02-13 19:42:47 -08:00
Matthew Johnson
0ff98a74eb add random.fold_in, update mnist_vae.py loops 2019-02-13 09:55:36 -08:00
Matthew Johnson
70d1a00443 set behavior when random.randint has invalid range
(closes #222)
2019-01-12 13:12:40 -08:00
Matthew Johnson
a627cc80e8 make random.split return something vmap-compatible
(in particular, return an array rather than a tuple, c.f. #181)
2019-01-02 12:52:39 -08:00
Matthew Johnson
072e6f78f1 replace PRNGKey class with uint32[2] array 2018-12-30 21:42:55 -08:00
Matthew Johnson
52c6eac3de use lax.tie_in in jax.random for better consts 2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2 add EyeConstant, new np.eye and np.array code 2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Matthew Johnson
a285017110 fix failing tests (misc small bugs) 2018-12-13 11:52:41 -08:00
Alex Wiltschko
0b7bed8778 Adding unimplemented functions to numpy.random, numpy.fft and numpy.linalg 2018-12-11 12:44:02 -05:00
Matthew Johnson
bbc92ce6eb
Split out jax and jaxlib packages (#11)
factor out 'jaxlib' as separate package
2018-12-06 21:35:03 -05:00
Peter Hawkins
c1b9eb19ea [JAX] Change semantics of dtype promotion to just call numpy.result_type.
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Matthew Johnson
c293f7c875 minor: add @jit to threefry hash function in random.py
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Matthew Johnson
8317cc3618 source sync
PiperOrigin-RevId: 222484671
2018-11-21 20:32:33 -08:00
Matthew Johnson
2ae9a2bc35 source sync
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Peter Hawkins
065bb0baa2 source sync
PiperOrigin-RevId: 222291726
2018-11-21 20:22:41 -08:00
Matthew Johnson
46c6a9170f sync updates 2018-11-19 07:47:59 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00