303 Commits

Author SHA1 Message Date
Erich Elsen
ae9e6851cc use correct iinfo finfo names 2020-06-28 19:44:36 +01:00
Erich Elsen
812d246295 don't require passing identity value. It isn't the initial value - identity is required for implementation correctness 2020-06-28 19:33:20 +01:00
Erich Elsen
95e15b64e3 fix typo 2020-06-28 18:37:50 +01:00
Erich Elsen
bf06633a87 add tests 2020-06-28 18:21:09 +01:00
Roy Frostig
ccb640afdb lax.sort: stable by default 2020-06-26 20:37:23 -07:00
Matthew Johnson
11caa21eca
ensure lax.reduce monoid test uses original numpy (#3573) 2020-06-26 11:44:16 -07:00
Norman Casagrande
99a43f20db
Added missing is_stable argument to lax.sort (#3553) 2020-06-26 10:40:00 -07:00
Jamie Townsend
c9670d50c5
Fix lazy broadcast issue (#3536) 2020-06-25 07:50:11 -07:00
Jake Vanderplas
d5a5d301f2
lax.sort: allow any sequence of Arrays, not just tuples (#3367) 2020-06-23 08:28:04 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
Jacob Kelly
575216e094
add jet primitives, refactor tests (#3468)
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-06-16 19:48:25 -07:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
Stephan Hoyer
3deada9ede
Document valid enum values for precision. (#3441)
This is a little tricky to figure out otherwise.
2020-06-14 21:42:45 -07:00
Jake Vanderplas
71461a37f3
Revert "Initial implementation of variadic lax.reduce() (#3342)" (#3384)
This reverts commit 99401c5a844cc19c6ce66cc26997f999c9ecf6d8.
2020-06-09 16:09:50 -04:00
Jake Vanderplas
99401c5a84
Initial implementation of variadic lax.reduce() (#3342) 2020-06-09 09:22:29 -07:00
Matthew Johnson
866c17c32e fix a couple ad_util.Zero type checks 2020-06-08 13:22:13 -07:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Jake Vanderplas
b187663a87
deflake jax/lax & add to flake8 check (#3310) 2020-06-04 13:50:44 -07:00
Roy Frostig
6015a2a689 introduce lax.switch 2020-06-03 22:19:15 -07:00
Skye Wanderman-Milne
5ad9feda5f
Fix handling of infeed token inside sharded_jit (#3313) 2020-06-03 15:23:49 -07:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00
Jake Vanderplas
0db57cb541
Fix validation code in lax.conv (#3279) 2020-06-03 10:33:19 -07:00
Peter Hawkins
dd81a8dded
Fix some type errors in lax.py found by pytype. (#3292) 2020-06-02 10:27:14 -04:00
Jamie Townsend
3909875f9d
Improve speed of tracing dynamic_update_slice (#3247)
* Improve tracing performance of _dynamic_slice_indices

* More precisely preserve semantics of dynamic_slice_indices

* Use safe_map in dynamic_slice_indices
2020-06-02 09:37:32 -04:00
James Bradbury
f1a7073738
pmap(in_axes=None) of sharded_jit (#3257)
* pmap(in_axes=None) of sharded_jit

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

* address comments

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-06-01 16:50:22 -07:00
Peter Hawkins
cf624196ed
Documentation fixes. (#3282)
Improve some cross-references and poorly quoted text.
2020-06-01 18:09:45 -04:00
Matthew Johnson
49a441f745
revisions to #3197 (#3264)
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
2020-06-01 13:24:40 -07:00
Skye Wanderman-Milne
f78ece0f98
Allow sharding infeed inside sharded_jit. (#3256) 2020-06-01 12:35:18 -07:00
Stephan Hoyer
cc8fbb7669
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.

This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.

I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.

I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)

* Remove unnecessary branch

* Add lax.squeeze primitive

* Changes per review

* Fix typing

* Move expand_dims into lax

* Update per review; add comments/documentation

* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
Julius Kunze
02b4fd3500
Fix broadcast_shapes for polymorphic dims (#3216) (#3224)
* Fix #3216

* Simplify
2020-05-27 18:15:01 -04:00
Skye Wanderman-Milne
6ffde8061d
Implement pmap of sharded_jit (#3144)
* Implement pmap of sharded_jit

* Update jax/interpreters/pxla.py

Co-authored-by: James Bradbury <jekbradbury@google.com>

* Address comments

Co-authored-by: James Bradbury <jekbradbury@google.com>
2020-05-26 14:26:53 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
alexdavies
85fe5a28f1
Add gradients to the scatter_max and scatter_min operations. (#3111)
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255

Co-authored-by: Alex Davies <adavies@google.com>
2020-05-18 23:06:32 -07:00
Skye Wanderman-Milne
888c9c77b3 Implement pmap of sharded_jit 2020-05-18 18:40:28 -07:00
Peter Hawkins
36e7fad1e2
Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140)
* Add a primitive integer_pow() for values raised to fixed integer scalar.

Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().

Fixes #3136

```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }

In [5]:
```

* Use x ** 3 in gelu definition.
2020-05-18 17:54:20 -04:00
Ed Schmerling
510af1de64
Fix documentation for nn.elu, nn.celu, and lax.expm1. (#3116) 2020-05-15 20:51:53 -07:00
Peter Hawkins
77703b8925
Add support for sorting complex values, defaulting to a NumPy-style l… (#3096)
* Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering.

Implemented using a custom comparator, since the XLA-level default comparator doesn't impose and ordering for complex values.

* Disable sort test on CPU and TPU.
2020-05-14 19:17:44 -04:00
Peter Hawkins
4ce2aa2563
Make lax.sort support tuple arguments using a variadic sort. (#3085)
* Make lax.sort support tuple arguments using a variadic sort.

Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.

Remove sort_key_val_p, since it is redundant with a variadic sort_p.

* Fix mypy errors.

* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
2020-05-14 11:13:15 -04:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
George Necula
970e475e0a
Undo strict checking of LAX primitives (#2996)
This undoes d08dec5d20
2020-05-07 16:16:22 +03:00
George Necula
d08dec5d63
Added argument check to all primitives. (#2948)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previosuly if core.skip_checks == False
because then `bind` checks its arguments. I have essentially
added an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and
`numpy` would report the error somehow, perhaps.

* Merged find_top_trace with check_args
2020-05-07 09:37:20 +03:00
notEvil
969ed8085c
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-05-06 10:15:17 -04:00
Srinivas Vasudevan
e51c7d7482
Add IgammaGradA (#2504) 2020-05-05 20:10:31 -04:00
tamaranorman
04102e5b9d
Allow ConvDimensionNumbers to be passed into conv_transpose (#2915) 2020-05-04 14:02:13 -04:00
James Bradbury
1cdd8f1b99
Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)
* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
2020-05-01 14:37:13 -07:00
Julius Kunze
c00e9a2a52
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
* Unrevert "Allow shapecheck of PixelCNN++ (google#2017)"

This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-05-01 12:34:29 -07:00
Tom Hennigan
0736679c33
Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901)
At head the following fails:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
2020-05-01 10:00:38 -07:00
George Necula
ac023bf28f
Fixed a few places where device sticky-ness was lost. Added FAQ (#2882)
* Fixed a few places where device sitckyness was lost. Added FAQ for device
placement.

I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorted than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.

* Fix mypy annotations and updates based on comments

* Undid some changes, will make another PR
2020-05-01 10:06:59 +03:00
Peter Hawkins
0557248fbd
Check for unsupported dtypes and issue a helpful error. (#2885) 2020-04-29 14:14:49 -04:00