261 Commits

Author SHA1 Message Date
Peter Hawkins
afdd1a7367
Add more return types to api.py. (#2452) 2020-03-19 10:28:29 -04:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Peter Hawkins
985d5f7327
Fix Python 3.5 support. (#2439)
* Fix Python 3.5 compatibility problems.
2020-03-17 17:01:04 -04:00
jheek
5f2d2254a0
Fix typo in ShapeDtypeStruct (#2253)
* fix ShapeDtypeSTruct dtype bug

* move dtype conversion to constructor
2020-03-17 07:45:17 +01:00
Matthew Johnson
efaedf4b57 undo previous commit 2020-03-16 12:13:25 -07:00
Matthew Johnson
5280793191 fix custom_transforms + jit bug from #2416 2020-03-16 10:23:24 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
92a0b3d40a add basic pytree support to jet 2020-03-15 09:58:54 -07:00
Matthew Johnson
e84a621184 new jet implementation, with conv-based rules
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Matthew Johnson
ebbcbad547
allow vmap in_axes to be a list, fixes #2367 (#2395) 2020-03-10 08:29:46 -07:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
George Necula
282225f676
Added some pytype annotations (#2386)
Tried to catch all uses of linear_util.WrappedFun
2020-03-09 20:41:01 +01:00
Chris Jones
0080c89548
Fix a few type annotations in api.py. (#2387) 2020-03-09 06:35:21 -07:00
Peter Hawkins
8e251e1f8d
Add pytype annotations to JAX API. (#2368) 2020-03-06 10:56:29 -05:00
Skye Wanderman-Milne
a1fa6296cc
Document jax.device_put (#2366) 2020-03-05 14:45:01 -08:00
Ram Rachum
52a41311c5
Fix exception causes in api.py (#2336) 2020-03-04 10:08:52 -05:00
George Necula
6e4ea4f70d
Revert "add np.copy method to abstract arrays (#2257)" (#2272)
This reverts commit ae1214de74e9ec42da8ff813dab8577c6bd9231d.

This is only to test the internal pre-submits.
2020-02-20 13:43:53 +01:00
Matthew Johnson
ae1214de74
add np.copy method to abstract arrays (#2257)
* add np.copy method to abstract arrays

fixes #2248

* make device_get use onp.asarray, not .copy()
2020-02-18 12:39:03 -08:00
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Alexander Botev
43ee917511
Adding broadcast_argnums to pmap for allowing similar behaviour t… (#1786)
* Adding `static_argnums` to `pmap` for similar behaviour to `static_argnums` of `jit`.

* Removed check for ShardedDeviceArray

* Final clean up and rename.
2020-02-14 07:45:26 -08:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
Tom Hennigan
9797ea2485
Implement size/ndim/__len__/repr/str/eq/hash for ShapeDtypeStruct. (#2206) 2020-02-11 09:11:48 -05:00
George Necula
b79c7948ee Removed dependency on distutils.strtobool 2020-02-06 17:27:46 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Matthew Johnson
ae1d6b875f
fix remat with nontrivial env (#2136)
fixes #2030
2020-01-31 23:47:30 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
brett koonce
e18d697ac6 minor spelling tweaks (#2043) 2020-01-22 17:18:00 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Jamie Townsend
3974df0aee [docs] Pmap compiles functions with XLA (#2021) 2020-01-17 09:48:27 -08:00
Julius Kunze
55c971e47f Implement shapecheck for more primitives (#1990)
* shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals

* WIP shapecheck np.pad

* Implement shapecheck of gather, pad

* Fix shapecheck of pad

* Implement polymorphic shape rule for (strided/dilated) convolution, refactor

* Cleanup

* Fix

* Remove all polymorphic shape rules, reuse shape rules instead.

* Register shape_rule for all standard_primitives

* Remove ShapeExpr, canonicalize_poly, renames

* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes

* Allow Poly of form d*poly + k to be divided by d

* Fix bug, inline poly_without_zeros.
2020-01-15 16:36:00 -08:00
Trevor Cai
12975bbcc8 [pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002)
One issue with nested pmaps on multihost platforms is inferring the global
pmap axis size without communication. This commit sidesteps the issue by adding
an `axis_size` argument to manually provide this information.

This change only enables a single cross-host pmap; all inner pmaps must be
single-host.

Addressing: #1753
2020-01-15 10:09:02 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Matthew Johnson
82dbf91311 add tests for #1640, adapt make_jaxpr staging 2019-12-31 11:53:02 -08:00
Pavel Sountsov
cc92bb6411 Improve the VJP structure mismatch errors. (#1854) 2019-12-13 08:41:51 -05:00
Stephan Hoyer
7bf2d77bd9
Clarify SPMD requirement for pmap (#1826) 2019-12-06 12:03:22 -08:00
Matthew Johnson
0899673363 switch xla_computation instantiate outputs default 2019-12-04 10:34:02 -08:00
Matthew Johnson
c1aeaf511c xla_computation option to instantiate const output 2019-12-04 10:34:02 -08:00
George Necula
2b0b04fcad Merge remote-tracking branch 'upstream/master' into jaxpr_pp 2019-11-28 08:56:00 +01:00
George Necula
0cb3b433b5 Change in how we print sorted params for eqns 2019-11-28 07:34:40 +01:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
e0706ff864 Relaxed check to allow both tuples and lists 2019-11-27 14:24:41 +01:00
George Necula
c1d8d3f74d Add error checking that arguments of jvp are tuples 2019-11-27 13:12:24 +01:00
George Necula
5c15dda2c9 Changed api.make_jaxpr to return a TypedJaxpr
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.
  Before:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [*]}
  After this change:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [1.0]}
2019-11-26 09:17:03 +01:00
Skye Wanderman-Milne
f415f266b8
Remove 'backend' argument from device_put. (#1762)
The appropriate Backend is instead inferred from the 'device' argument. This is a first step towards removing the 'backend' argument from more functions.
2019-11-25 16:23:40 -08:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Matthew Johnson
728cb7fba8 improve grad error message without enough args
fixes #1696
2019-11-14 21:18:23 -08:00
Peter Hawkins
6125157db8
Add type checks that verify JVP primal inputs have the same types as tangent inputs, and JVP cotangent inputs have the same type as primal outputs. (#1690) 2019-11-14 15:37:33 -05:00