224 Commits

Author SHA1 Message Date
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Matthew Johnson
f1d9130f25 remove safe_mul (undo #383, also cf. #1052) 2020-03-17 22:07:53 -07:00
George Necula
0ddc2ec360 Fixed failing tests 2020-03-17 06:51:01 +01:00
George Necula
5cf82c756e Improved argument checking for lax.broadcast_in_dim
* Added checking that the output shape has higher or equal rank to input
* Added checking that the broadcast_dims are sorted (required by XLA)
* Relaxed check that operand dimension size can be 1
* Added lax.broadcast_in_dim docstring
2020-03-17 06:51:01 +01:00
Matthew Johnson
c0c3a4a506
Merge pull request #2401 from hawkinsp/ones
Check for invalid shapes in broadcast_in_dim and fail gracefully.
2020-03-16 19:46:52 -07:00
Roy Frostig
6545cf3421
Merge pull request #2424 from google/broadcast-shapecheck
add lax.broadcast_in_dim shape check and test
2020-03-15 22:22:24 -07:00
Roy Frostig
94832f9627 add lax.broadcast_in_dim shape check and test
Operand dimensions must equal their corresponding dimensions in the broadcast shape.
2020-03-15 20:30:44 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
668a1703bc add jet tests, remove top-level files 2020-03-14 21:22:10 -07:00
Jacob Kelly
840797d4a1 refactor reduce_max jet rule 2020-03-14 18:42:51 -07:00
Jacob Kelly
b4d003d460 jet rule for log
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
30830dfc25 linear rule for sub
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
dcebe50562 jet for reduce_max
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
3bcf02a191 Add gather rule
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Jacob Kelly
098aabefcd fix typo 2020-03-14 18:42:51 -07:00
Jacob Kelly
ddd52c4730 adding div and linear prims
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:42:51 -07:00
Matthew Johnson
7adf9fe84f add more jet rules!
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
a21fdf8669 more jet rules and tests
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
e84a621184 new jet implementation, with conv-based rules
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-03-14 18:41:44 -07:00
Matthew Johnson
7f0463e2c9
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }

The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!

Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }

That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!

That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>

Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
Peter Hawkins
419961f9dd Check for invalid shapes in broadcast_in_dim and fail gracefully. 2020-03-11 09:57:20 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
Skye Wanderman-Milne
efa0315c8f
[docs] Add docstring for jax.lax.tie_in (#2364) 2020-03-05 16:21:19 -08:00
Peter Hawkins
0416d2a5f2
Fix abstract evaluation rule for lax.top_k. (#2290) 2020-02-24 07:31:46 -08:00
Peter Hawkins
af0967fdbf
Add an experimental lax.top_k operator. (#2280) 2020-02-20 17:15:25 -08:00
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Skye Wanderman-Milne
18420936c4
_scatter_jvp bug fix (#2231) 2020-02-14 18:09:52 -08:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Matthew Johnson
e140520466
make pmap inside of eager scan work, fixes #2018 (#2183)
* make pmap inside of eager scan work, fixes #2018

Co-authored-by: Sharad Vikram <sharadmv@google.com>

* Ensure AxisEnv is instantiated with tuples (#2186)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2020-02-06 17:19:54 -08:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
Peter Hawkins
0904e5ff74
Fix implementation of cumsum/cumprod for boolean inputs. (#2112)
Check for number inputs in the reduce_window_sum dtype rule.
2020-01-29 10:51:39 -05:00
Roman Novak
6a4bb95169
Mare the reverse operator work on empty list of dimensions
Example that this fixes:
```
from jax import lax
import jax.numpy as np
from jax.api import jacrev

x = np.ones((3, 5))

def f(x):
  return lax.conv_general_dilated(lhs=x, 
                                  rhs=np.ones((5, 2)), 
                                  window_strides=(), 
                                  padding='VALID', 
                                  dimension_numbers=('NC', 'IO', 'NC'))
  
jacrev(f)(x)
```
currently gives
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-136-2ad65e41f1de> in <module>()
     12                                   dimension_numbers=('NC', 'IO', 'NC'))
     13 
---> 14 jacrev(f)(x).shape

15 frames
google3/third_party/py/jax/api.py in jacfun(*args, **kwargs)
    514     y, pullback = vjp(f_partial, *dyn_args)
    515     holomorphic or tree_map(_check_real_output_jacrev, y)
--> 516     jac = vmap(pullback)(_std_basis(y))
    517     jac = jac[0] if isinstance(argnums, int) else jac
    518     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

google3/third_party/py/jax/api.py in batched_fun(*args)
    692     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    693     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 694                               lambda: _flatten_axes(out_tree(), out_axes))
    695     return tree_unflatten(out_tree(), out_flat)
    696 

google3/third_party/py/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     38 def batch(fun, in_vals, in_dims, out_dim_dests):
     39   size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 40   out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
     41   return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
     42 

google3/third_party/py/jax/interpreters/batching.py in batch_fun(fun, in_vals, in_dims)
     44   with new_master(BatchTrace) as master:
     45     fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 46     out_vals = fun.call_wrapped(*in_vals)
     47     del master
     48   return out_vals, out_dims()

google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     gen = None
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args
    154     while stack:

google3/third_party/py/jax/api.py in _vjp_pullback_wrapper(fun, cotangent_dtypes, io_tree, py_args)
   1237              "match type of corresponding primal output ({})")
   1238       raise TypeError(msg.format(_dtype(a), dtype))
-> 1239   ans = fun(*args)
   1240   return tree_unflatten(out_tree, ans)
   1241 

google3/third_party/py/jax/interpreters/ad.py in vjp_(*cts)
    114     dummy_primals_and_cts = (core.unit,) * len(cts) + cts
    115     dummy_args = (undefined_primal,) * len(jaxpr.invars)
--> 116     _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts)
    117     arg_cts = arg_cts[len(primals):]
    118     return map(instantiate_zeros, primals, arg_cts)

google3/third_party/py/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in)
    222       map(write_cotangent, bound_vars, ct_free_vars_out)
    223     else:
--> 224       cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
    225     cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    226     map(write_cotangent, eqn.invars, cts_out)

google3/third_party/py/jax/interpreters/ad.py in bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs)
    505   assert (x is undefined_primal) ^ (y is undefined_primal)
    506   if x is undefined_primal:
--> 507     out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
    508     return out, None
    509   else:

google3/third_party/py/jax/lax/lax.py in _conv_general_dilated_transpose_lhs(g, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, lhs_shape, rhs_shape, precision)
   2042       window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
   2043       rhs_dilation)
-> 2044   revd_weights = rev(rhs, rhs_sdims)
   2045   return conv_general_dilated(
   2046       g, revd_weights, window_strides=lhs_dilation, padding=padding,

google3/third_party/py/jax/lax/lax.py in rev(operand, dimensions)
    671   operator.
    672   """
--> 673   return rev_p.bind(operand, dimensions=tuple(dimensions))
    674 
    675 def select(pred, on_true, on_false):

google3/third_party/py/jax/core.py in bind(self, *args, **kwargs)
    157     top_trace = find_top_trace(args)
    158     if top_trace is None:
--> 159       return self.impl(*args, **kwargs)
    160 
    161     tracers = map(top_trace.full_raise, args)

google3/third_party/py/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    159 def apply_primitive(prim, *args, **params):
    160   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 161   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
    162   return compiled_fun(*args)
    163 

google3/third_party/py/jax/interpreters/xla.py in xla_primitive_callable(prim, *arg_specs, **params)
    167   device = _device_from_arg_devices(arg_devices)
    168   backend = xb.get_device_backend(device)
--> 169   aval_out = prim.abstract_eval(*avals, **params)
    170   if not prim.multiple_results:
    171     handle_result = aval_to_result_handler(device, aval_out)

google3/third_party/py/jax/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
   1540     return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
   1541   elif least_specialized is ShapedArray:
-> 1542     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1543   elif least_specialized is UnshapedArray:
   1544     return UnshapedArray(dtype_rule(*args, **kwargs))

google3/third_party/py/jax/lax/lax.py in _rev_shape_rule(operand, dimensions)
   2620     msg = 'rev dimensions must be unique, got {}.'
   2621     raise TypeError(msg.format(dimensions))
-> 2622   if not _max(dimensions) < operand.ndim:
   2623     msg = ('rev dimensions must all be less than operand ndim, got dimensions '
   2624            '{} for operand ndim {}.')

ValueError: max() arg is an empty sequence
```
2020-01-27 00:16:04 -08:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Peter Hawkins
632326ac5c
Add unsupported wrapper around XLA RngUniform API. (#2068) 2020-01-24 16:58:00 -05:00
Peter Hawkins
a3de80201f
Fix type specifications for bitwise ops. (#2054) 2020-01-23 11:53:55 -05:00
brett koonce
e18d697ac6 minor spelling tweaks (#2043) 2020-01-22 17:18:00 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Julius Kunze
55c971e47f Implement shapecheck for more primitives (#1990)
* shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals

* WIP shapecheck np.pad

* Implement shapecheck of gather, pad

* Fix shapecheck of pad

* Implement polymorphic shape rule for (strided/dilated) convolution, refactor

* Cleanup

* Fix

* Remove all polymorphic shape rules, reuse shape rules instead.

* Register shape_rule for all standard_primitives

* Remove ShapeExpr, canonicalize_poly, renames

* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes

* Allow Poly of form d*poly + k to be divided by d

* Fix bug, inline poly_without_zeros.
2020-01-15 16:36:00 -08:00
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
Peter Hawkins
facbe0d76a
Handle 0D convolutions correctly in shape rule. (#1972) 2020-01-09 14:36:37 -05:00
Matthew Johnson
327dca8f76
Merge pull request #1944 from clemisch/master
Implement numpy.gradient
2020-01-09 10:46:57 -08:00
Peter Hawkins
ab2582585e
Implement np.sign for unsigned integers. (#1970)
Fix definition of np.sign for complex numbers.
Document lax.sign better for non-float types.
2020-01-09 11:16:52 -05:00
clemisch
c907504078
Merge branch 'master' into master 2020-01-09 07:42:55 +01:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Clemens Schmid
48cb6af6b4 Support None and negative indices in slice_in_dim 2020-01-08 12:22:12 +01:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Matthew Johnson
bb9cd23368 tweak shape error message, add test 2020-01-06 21:38:00 -08:00