84 Commits

Author SHA1 Message Date
Julius Kunze
9d12a24b63 Add categorical sampler 2019-12-13 12:41:26 +00:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Matthew Johnson
b358c27c92 replace x.shape with onp.shape(x) in random.py
fixes #1748 (thanks @vitchyr)
2019-11-22 10:59:31 -08:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
Matthew Johnson
1f4e45cdcd tweak shape convention again 2019-10-20 21:14:48 +00:00
Matthew Johnson
ab6ac6c876 standardize shape handling in jax.random 2019-10-20 19:37:37 +00:00
Matthew Johnson
8a132b4109 try simplifying random.multivariate_normal api 2019-10-19 22:27:43 +00:00
James Bradbury
74dc4bf72a
Merge pull request #1419 from aeftimia/multivariate-normal
Multivariate normal distribution support
2019-10-15 16:46:13 -07:00
Alex Eftimiades
2f4e88760c fix indentation 2019-10-15 14:35:40 -04:00
Alex Eftimiades
7939bad65a Merge branch 'multivariate-normal' of https://github.com/aeftimia/jax into multivariate-normal 2019-10-12 11:38:06 -04:00
Alex Eftimiades
ed1bd5fccd test cases for ValueError with wrong shape; one test case with negative mean 2019-10-10 09:33:28 -04:00
Alex Eftimiades
d344c6f5e4
Update jax/random.py
Fix typo

Co-Authored-By: James Bradbury <jekbradbury@google.com>
2019-10-04 09:57:17 -04:00
Alex Eftimiades
c1057f77ae minor readability cleanup 2019-10-02 11:39:49 -04:00
Alex Eftimiades
ea67fa8f09 fix failing tests with full covariance matrix 2019-10-02 11:08:17 -04:00
Alex Eftimiades
ad3f25b02f small cleanup and retrigger tests 2019-10-02 09:25:55 -04:00
Trevor Cai
a44c1caf21 Use finfo(dtype).tiny as uniform minval
Otherwise, using the default clopen uniform for truncated_normal
introduces a slight shift of the empirical mean.
2019-10-01 22:28:31 +01:00
Alex Eftimiades
b2b10a3c96 Multivariate normal support
Tested with inverse transform and KS test
2019-10-01 16:49:57 -04:00
Alex Eftimiades
b70a7800d7 multivariate normal support 2019-09-23 20:21:28 -04:00
James Bradbury
146b5d121a
Merge pull request #1262 from google/jb/initializers
Migrate initializers and activation functions to jax.nn
2019-09-04 14:48:20 -07:00
James Bradbury
bf28c44ada address comments 2019-09-03 17:51:29 -07:00
Matthew Johnson
2815c55bb1 switch rolled/unrolled loops in prng hash 2019-08-31 07:35:37 -07:00
James Bradbury
f4aeb363a9 add truncated normal 2019-08-29 18:14:57 -07:00
David Majnemer
1dbdaab765 Use log1p when computing log(1 + x) or log(1 - x)
log(1 + x) is less accurate when its input is near zero whereas log1p
can compute the result without excessive accuracy loss.
2019-08-16 13:44:09 -07:00
Peter Hawkins
719e17ba8e Avoid dynamic slicing in threefry implementation.
The dynamic slice when batched currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter at the moment.
2019-08-15 16:37:04 -04:00
Matthew Johnson
75bb38e741 address reviewer comments, no op-by-op in threefry 2019-08-13 11:30:24 -07:00
Matthew Johnson
d857d3f595 improve prng compile times with outer-loop rolling 2019-08-13 09:43:44 -07:00
Matthew Johnson
275bf9da6d improve prng compile times with loop rolling
cf. #1172
2019-08-13 07:20:09 -07:00
Jamie Townsend
f351dedbb7 Add logistic distribution to jax.random 2019-08-06 12:19:05 +01:00
Matthew Johnson
ec456a181b
Update random.py
Update `jax.random.fold_in` docstring to specify  `data` is treated as a 32bit integer.
2019-07-23 12:21:28 +03:00
Peter Zhokhov
8985d684a1 remove static argnums from random.fold_in 2019-07-19 12:04:33 -07:00
David Majnemer
5607e46922 Avoid generating non-finite values from cauchy
If uniform generates 0, -pi/2 will be sent to tan resulting in a
non-finite result. Instead, generate values on (0,1).
2019-07-06 21:11:02 -07:00
David Majnemer
52fa63af48 Avoid generating non-finite values from gumbel and laplace
In the case of gumbel, we take the log(-log(x)), as such we would not want to let x be 0 or 1 as we would get a non-finite number.

In the case of laplace, we take the log1p(-abs(x)), as such we would not want to let x be -1 or 1 as we would get a non-finite number.

This was found by inspection, I have no evidence that this happens in practice.
2019-07-05 11:01:28 -07:00
fehiepsi
dc91d00afd use split 2 instead of split 1 2019-06-27 17:28:36 -04:00
fehiepsi
d284388577 pass tests locally 2019-06-27 13:07:26 -04:00
fehiepsi
907925fb30 improve compling time of gamma 2019-06-27 13:07:26 -04:00
fehiepsi
9252d596eb implement gamma grad 2019-06-27 13:07:25 -04:00
Matthew Johnson
a56a7d02ff make threefry_2x32 not do any op-by-op stuff 2019-06-11 14:56:21 -07:00
Matthew Johnson
c743c14c8f address reviewer comments 2019-05-22 20:06:12 -07:00
Matthew Johnson
93e201f85b make jax.random default dtypes 64-bit
fixes #756
2019-05-22 16:22:12 -07:00
Matthew Johnson
b68ed2787a fix randint docstring 2019-05-16 10:36:30 -07:00
Matthew Johnson
ea8e414a83 improve jax.random shape error messages 2019-05-09 11:40:19 -07:00
Matthew Johnson
910eac8c01
Merge branch 'master' into dirichlet 2019-04-23 17:07:51 -07:00
Matthew Johnson
376ee6423f
Merge pull request #631 from fehiepsi/bernoulli
Implement bernoulli logpmf
2019-04-23 16:59:40 -07:00
fehiepsi
a790436be4 add dirichlet sampler 2019-04-22 11:55:02 -04:00
fehiepsi
665e72a23a implement bernoulli pmf 2019-04-21 21:22:50 -04:00
fehiepsi
088081711d add tests 2019-04-21 16:43:18 -04:00
fehiepsi
e5ccf0534d add beta/t random samplers 2019-04-21 16:25:20 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00