Matthew Johnson
ec456a181b
Update random.py
...
Update `jax.random.fold_in` docstring to specify `data` is treated as a 32bit integer.
2019-07-23 12:21:28 +03:00
Peter Zhokhov
8985d684a1
remove static argnums from random.fold_in
2019-07-19 12:04:33 -07:00
David Majnemer
5607e46922
Avoid generating non-finite values from cauchy
...
If uniform generates 0, -pi/2 will be sent to tan resulting in a
non-finite result. Instead, generate values on (0,1).
2019-07-06 21:11:02 -07:00
David Majnemer
52fa63af48
Avoid generating non-finite values from gumbel and laplace
...
In the case of gumbel, we take the log(-log(x)), as such we would not want to let x be 0 or 1 as we would get a non-finite number.
In the case of laplace, we take the log1p(-abs(x)), as such we would not want to let x be -1 or 1 as we would get a non-finite number.
This was found by inspection, I have no evidence that this happens in practice.
2019-07-05 11:01:28 -07:00
fehiepsi
dc91d00afd
use split 2 instead of split 1
2019-06-27 17:28:36 -04:00
fehiepsi
d284388577
pass tests locally
2019-06-27 13:07:26 -04:00
fehiepsi
907925fb30
improve compling time of gamma
2019-06-27 13:07:26 -04:00
fehiepsi
9252d596eb
implement gamma grad
2019-06-27 13:07:25 -04:00
Matthew Johnson
a56a7d02ff
make threefry_2x32 not do any op-by-op stuff
2019-06-11 14:56:21 -07:00
Matthew Johnson
c743c14c8f
address reviewer comments
2019-05-22 20:06:12 -07:00
Matthew Johnson
93e201f85b
make jax.random default dtypes 64-bit
...
fixes #756
2019-05-22 16:22:12 -07:00
Matthew Johnson
b68ed2787a
fix randint docstring
2019-05-16 10:36:30 -07:00
Matthew Johnson
ea8e414a83
improve jax.random shape error messages
2019-05-09 11:40:19 -07:00
Matthew Johnson
910eac8c01
Merge branch 'master' into dirichlet
2019-04-23 17:07:51 -07:00
Matthew Johnson
376ee6423f
Merge pull request #631 from fehiepsi/bernoulli
...
Implement bernoulli logpmf
2019-04-23 16:59:40 -07:00
fehiepsi
a790436be4
add dirichlet sampler
2019-04-22 11:55:02 -04:00
fehiepsi
665e72a23a
implement bernoulli pmf
2019-04-21 21:22:50 -04:00
fehiepsi
088081711d
add tests
2019-04-21 16:43:18 -04:00
fehiepsi
e5ccf0534d
add beta/t random samplers
2019-04-21 16:25:20 -04:00
Matthew Johnson
0cf14837c9
make a lax package, revert control flow names ( #607 )
...
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Skye Wanderman-Milne
105e46f379
Factor out control flow from lax.py into lax_control_flow.py.
...
Also moves control flow tests to lax_control_flow_test.py.
2019-04-12 13:57:18 -07:00
Matthew Johnson
35e340670c
Merge branch 'master' into gamma
2019-04-12 07:15:41 -07:00
Matthew Johnson
9c2e1c35b1
prevent jit from treating keyword args as static
...
fixes #523
2019-04-10 22:09:14 -07:00
Jamie Townsend
996c62337c
Rm unnecessary dtype arg name
2019-04-02 10:56:44 +01:00
Jamie Townsend
7b13ae4b41
Add Gumbel to jax.random, and test
2019-04-02 10:55:03 +01:00
fehiepsi
69c4e22524
add dtype for samplers in gamma_one
2019-04-01 00:32:42 -04:00
fehiepsi
80e22d2fdf
Merge remote-tracking branch 'upstream/master' into gamma
2019-03-31 23:54:46 -04:00
fehiepsi
cbf45282ee
convert gamma_one to lax api and move the inner while_loop to the outside
2019-03-31 23:54:31 -04:00
fehiepsi
e0567b6d16
add Gamma sampler
2019-03-30 18:07:34 -04:00
fehiepsi
bb095d3df5
implement pareto logpdf and sampler
2019-03-30 16:34:20 -04:00
fehiepsi
890ba842a9
implement expon and laplace sampler
2019-03-28 23:57:00 -04:00
fehiepsi
a0636eaedd
implement cauchy pdf and random sampler
2019-03-28 17:59:42 -04:00
fehiepsi
f05d0bcbd8
fix bernoulli shape bug
2019-03-02 19:11:16 -05:00
Masahiro H
d3acbc1c73
Fix typo in PRNGKey docstring
2019-02-16 23:31:27 +09:00
Matthew Johnson
910848afe1
fix typo in random.py docstring
2019-02-13 20:05:15 -08:00
Matthew Johnson
11122bc8e3
improve jax.random docs
2019-02-13 19:42:47 -08:00
Matthew Johnson
0ff98a74eb
add random.fold_in, update mnist_vae.py loops
2019-02-13 09:55:36 -08:00
Matthew Johnson
70d1a00443
set behavior when random.randint has invalid range
...
(closes #222 )
2019-01-12 13:12:40 -08:00
Matthew Johnson
a627cc80e8
make random.split return something vmap-compatible
...
(in particular, return an array rather than a tuple, c.f. #181 )
2019-01-02 12:52:39 -08:00
Matthew Johnson
072e6f78f1
replace PRNGKey class with uint32[2] array
2018-12-30 21:42:55 -08:00
Matthew Johnson
52c6eac3de
use lax.tie_in in jax.random for better consts
2018-12-18 09:16:59 -08:00
Matthew Johnson
1ae1ae17a2
add EyeConstant, new np.eye and np.array code
2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0
Tracer.__len__ should reflect on abstract value
...
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Matthew Johnson
a285017110
fix failing tests (misc small bugs)
2018-12-13 11:52:41 -08:00
Alex Wiltschko
0b7bed8778
Adding unimplemented functions to numpy.random, numpy.fft and numpy.linalg
2018-12-11 12:44:02 -05:00
Matthew Johnson
bbc92ce6eb
Split out jax
and jaxlib
packages ( #11 )
...
factor out 'jaxlib' as separate package
2018-12-06 21:35:03 -05:00
Peter Hawkins
c1b9eb19ea
[JAX] Change semantics of dtype promotion to just call numpy.result_type.
...
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Matthew Johnson
c293f7c875
minor: add @jit to threefry hash function in random.py
...
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Matthew Johnson
8317cc3618
source sync
...
PiperOrigin-RevId: 222484671
2018-11-21 20:32:33 -08:00
Matthew Johnson
2ae9a2bc35
source sync
...
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00