102 Commits

Author SHA1 Message Date
John Aslanides
c06fe56fc5 Add some types to jax.random and jnp.ndarray. 2020-04-12 09:14:54 +01:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Colin
6d0d6fd6c7
Docstring typo (#2228) 2020-02-14 08:04:20 -08:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
Peter Hawkins
0b1d2fc3d1
Avoid accidental type promotion in gamma sampler gradient. (#2150)
Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.
2020-02-03 12:44:46 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
00be20bdfa
Merge pull request #1855 from JuliusKunze/categorical
Add categorical sampler
2020-01-10 07:59:21 -08:00
Julius Kunze
f36d858c4e Require shape = sample_shape + batch_shape in random.categorical 2020-01-10 13:28:03 +00:00
fehiepsi
edf0e61bc9 support nested vmap for gamma sampler 2019-12-26 22:43:06 -05:00
fehiepsi
c75bf4ab72 make beta sampler faster 2019-12-23 23:02:08 -05:00
fehiepsi
cdfa57dfcc merge master 2019-12-23 22:52:15 -05:00
Matthew Johnson
7175c1dfe1 fix transpose bug in multivariate normal, add test
fixes #1869
2019-12-17 15:08:08 -08:00
Julius Kunze
698327080b Clarify documentation 2019-12-15 15:43:39 +00:00
Julius Kunze
9d12a24b63 Add categorical sampler 2019-12-13 12:41:26 +00:00
fehiepsi
7ec2ac58ca not use custom transform for gamma sampler 2019-12-01 09:44:45 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
Matthew Johnson
b358c27c92 replace x.shape with onp.shape(x) in random.py
fixes #1748 (thanks @vitchyr)
2019-11-22 10:59:31 -08:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
Matthew Johnson
1f4e45cdcd tweak shape convention again 2019-10-20 21:14:48 +00:00
Matthew Johnson
ab6ac6c876 standardize shape handling in jax.random 2019-10-20 19:37:37 +00:00
Matthew Johnson
8a132b4109 try simplifying random.multivariate_normal api 2019-10-19 22:27:43 +00:00
James Bradbury
74dc4bf72a
Merge pull request #1419 from aeftimia/multivariate-normal
Multivariate normal distribution support
2019-10-15 16:46:13 -07:00
Alex Eftimiades
2f4e88760c fix indentation 2019-10-15 14:35:40 -04:00
Alex Eftimiades
7939bad65a Merge branch 'multivariate-normal' of https://github.com/aeftimia/jax into multivariate-normal 2019-10-12 11:38:06 -04:00
Alex Eftimiades
ed1bd5fccd test cases for ValueError with wrong shape; one test case with negative mean 2019-10-10 09:33:28 -04:00
Alex Eftimiades
d344c6f5e4
Update jax/random.py
Fix typo

Co-Authored-By: James Bradbury <jekbradbury@google.com>
2019-10-04 09:57:17 -04:00
Alex Eftimiades
c1057f77ae minor readability cleanup 2019-10-02 11:39:49 -04:00
Alex Eftimiades
ea67fa8f09 fix failing tests with full covariance matrix 2019-10-02 11:08:17 -04:00
Alex Eftimiades
ad3f25b02f small cleanup and retrigger tests 2019-10-02 09:25:55 -04:00
Trevor Cai
a44c1caf21 Use finfo(dtype).tiny as uniform minval
Otherwise, using the default clopen uniform for truncated_normal
introduces a slight shift of the empirical mean.
2019-10-01 22:28:31 +01:00
Alex Eftimiades
b2b10a3c96 Multivariate normal support
Tested with inverse transform and KS test
2019-10-01 16:49:57 -04:00
Alex Eftimiades
b70a7800d7 multivariate normal support 2019-09-23 20:21:28 -04:00
James Bradbury
146b5d121a
Merge pull request #1262 from google/jb/initializers
Migrate initializers and activation functions to jax.nn
2019-09-04 14:48:20 -07:00
James Bradbury
bf28c44ada address comments 2019-09-03 17:51:29 -07:00
Matthew Johnson
2815c55bb1 switch rolled/unrolled loops in prng hash 2019-08-31 07:35:37 -07:00
James Bradbury
f4aeb363a9 add truncated normal 2019-08-29 18:14:57 -07:00
David Majnemer
1dbdaab765 Use log1p when computing log(1 + x) or log(1 - x)
log(1 + x) is less accurate when its input is near zero whereas log1p
can compute the result without excessive accuracy loss.
2019-08-16 13:44:09 -07:00
Peter Hawkins
719e17ba8e Avoid dynamic slicing in threefry implementation.
The dynamic slice when batched currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter at the moment.
2019-08-15 16:37:04 -04:00
Matthew Johnson
75bb38e741 address reviewer comments, no op-by-op in threefry 2019-08-13 11:30:24 -07:00
Matthew Johnson
d857d3f595 improve prng compile times with outer-loop rolling 2019-08-13 09:43:44 -07:00
Matthew Johnson
275bf9da6d improve prng compile times with loop rolling
cf. #1172
2019-08-13 07:20:09 -07:00
Jamie Townsend
f351dedbb7 Add logistic distribution to jax.random 2019-08-06 12:19:05 +01:00
Matthew Johnson
ec456a181b
Update random.py
Update `jax.random.fold_in` docstring to specify  `data` is treated as a 32bit integer.
2019-07-23 12:21:28 +03:00
Peter Zhokhov
8985d684a1 remove static argnums from random.fold_in 2019-07-19 12:04:33 -07:00