65 Commits

Author SHA1 Message Date
Trevor Cai
a44c1caf21 Use finfo(dtype).tiny as uniform minval
Otherwise, using the default clopen uniform for truncated_normal
introduces a slight shift of the empirical mean.
2019-10-01 22:28:31 +01:00
James Bradbury
146b5d121a
Merge pull request #1262 from google/jb/initializers
Migrate initializers and activation functions to jax.nn
2019-09-04 14:48:20 -07:00
James Bradbury
bf28c44ada address comments 2019-09-03 17:51:29 -07:00
Matthew Johnson
2815c55bb1 switch rolled/unrolled loops in prng hash 2019-08-31 07:35:37 -07:00
James Bradbury
f4aeb363a9 add truncated normal 2019-08-29 18:14:57 -07:00
David Majnemer
1dbdaab765 Use log1p when computing log(1 + x) or log(1 - x)
log(1 + x) is less accurate when its input is near zero whereas log1p
can compute the result without excessive accuracy loss.
2019-08-16 13:44:09 -07:00
Peter Hawkins
719e17ba8e Avoid dynamic slicing in threefry implementation.
The dynamic slice when batched currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter at the moment.
2019-08-15 16:37:04 -04:00
Matthew Johnson
75bb38e741 address reviewer comments, no op-by-op in threefry 2019-08-13 11:30:24 -07:00
Matthew Johnson
d857d3f595 improve prng compile times with outer-loop rolling 2019-08-13 09:43:44 -07:00
Matthew Johnson
275bf9da6d improve prng compile times with loop rolling
cf. #1172
2019-08-13 07:20:09 -07:00
Jamie Townsend
f351dedbb7 Add logistic distribution to jax.random 2019-08-06 12:19:05 +01:00
Matthew Johnson
ec456a181b
Update random.py
Update `jax.random.fold_in` docstring to specify  `data` is treated as a 32bit integer.
2019-07-23 12:21:28 +03:00
Peter Zhokhov
8985d684a1 remove static argnums from random.fold_in 2019-07-19 12:04:33 -07:00
David Majnemer
5607e46922 Avoid generating non-finite values from cauchy
If uniform generates 0, -pi/2 will be sent to tan resulting in a
non-finite result. Instead, generate values on (0,1).
2019-07-06 21:11:02 -07:00
David Majnemer
52fa63af48 Avoid generating non-finite values from gumbel and laplace
In the case of gumbel, we take the log(-log(x)), as such we would not want to let x be 0 or 1 as we would get a non-finite number.

In the case of laplace, we take the log1p(-abs(x)), as such we would not want to let x be -1 or 1 as we would get a non-finite number.

This was found by inspection, I have no evidence that this happens in practice.
2019-07-05 11:01:28 -07:00
fehiepsi
dc91d00afd use split 2 instead of split 1 2019-06-27 17:28:36 -04:00
fehiepsi
d284388577 pass tests locally 2019-06-27 13:07:26 -04:00
fehiepsi
907925fb30 improve compling time of gamma 2019-06-27 13:07:26 -04:00
fehiepsi
9252d596eb implement gamma grad 2019-06-27 13:07:25 -04:00
Matthew Johnson
a56a7d02ff make threefry_2x32 not do any op-by-op stuff 2019-06-11 14:56:21 -07:00
Matthew Johnson
c743c14c8f address reviewer comments 2019-05-22 20:06:12 -07:00
Matthew Johnson
93e201f85b make jax.random default dtypes 64-bit
fixes #756
2019-05-22 16:22:12 -07:00
Matthew Johnson
b68ed2787a fix randint docstring 2019-05-16 10:36:30 -07:00
Matthew Johnson
ea8e414a83 improve jax.random shape error messages 2019-05-09 11:40:19 -07:00
Matthew Johnson
910eac8c01
Merge branch 'master' into dirichlet 2019-04-23 17:07:51 -07:00
Matthew Johnson
376ee6423f
Merge pull request #631 from fehiepsi/bernoulli
Implement bernoulli logpmf
2019-04-23 16:59:40 -07:00
fehiepsi
a790436be4 add dirichlet sampler 2019-04-22 11:55:02 -04:00
fehiepsi
665e72a23a implement bernoulli pmf 2019-04-21 21:22:50 -04:00
fehiepsi
088081711d add tests 2019-04-21 16:43:18 -04:00
fehiepsi
e5ccf0534d add beta/t random samplers 2019-04-21 16:25:20 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Skye Wanderman-Milne
105e46f379 Factor out control flow from lax.py into lax_control_flow.py.
Also moves control flow tests to lax_control_flow_test.py.
2019-04-12 13:57:18 -07:00
Matthew Johnson
35e340670c
Merge branch 'master' into gamma 2019-04-12 07:15:41 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Jamie Townsend
996c62337c Rm unnecessary dtype arg name 2019-04-02 10:56:44 +01:00
Jamie Townsend
7b13ae4b41 Add Gumbel to jax.random, and test 2019-04-02 10:55:03 +01:00
fehiepsi
69c4e22524 add dtype for samplers in gamma_one 2019-04-01 00:32:42 -04:00
fehiepsi
80e22d2fdf Merge remote-tracking branch 'upstream/master' into gamma 2019-03-31 23:54:46 -04:00
fehiepsi
cbf45282ee convert gamma_one to lax api and move the inner while_loop to the outside 2019-03-31 23:54:31 -04:00
fehiepsi
e0567b6d16 add Gamma sampler 2019-03-30 18:07:34 -04:00
fehiepsi
bb095d3df5 implement pareto logpdf and sampler 2019-03-30 16:34:20 -04:00
fehiepsi
890ba842a9 implement expon and laplace sampler 2019-03-28 23:57:00 -04:00
fehiepsi
a0636eaedd implement cauchy pdf and random sampler 2019-03-28 17:59:42 -04:00
fehiepsi
f05d0bcbd8 fix bernoulli shape bug 2019-03-02 19:11:16 -05:00
Masahiro H
d3acbc1c73
Fix typo in PRNGKey docstring 2019-02-16 23:31:27 +09:00
Matthew Johnson
910848afe1 fix typo in random.py docstring 2019-02-13 20:05:15 -08:00
Matthew Johnson
11122bc8e3 improve jax.random docs 2019-02-13 19:42:47 -08:00
Matthew Johnson
0ff98a74eb add random.fold_in, update mnist_vae.py loops 2019-02-13 09:55:36 -08:00
Matthew Johnson
70d1a00443 set behavior when random.randint has invalid range
(closes #222)
2019-01-12 13:12:40 -08:00
Matthew Johnson
a627cc80e8 make random.split return something vmap-compatible
(in particular, return an array rather than a tuple, c.f. #181)
2019-01-02 12:52:39 -08:00