201 Commits

Author SHA1 Message Date
Peter Hawkins
0416d2a5f2
Fix abstract evaluation rule for lax.top_k. (#2290) 2020-02-24 07:31:46 -08:00
Peter Hawkins
af0967fdbf
Add an experimental lax.top_k operator. (#2280) 2020-02-20 17:15:25 -08:00
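A minimal usage sketch (not part of the commit; assumes a current JAX install) of the new operator, which returns the k largest values along the last axis together with their indices:
```
import jax.numpy as np
from jax import lax

x = np.array([4.0, 1.0, 7.0, 3.0])
values, indices = lax.top_k(x, k=2)
print(values)   # [7. 4.]
print(indices)  # [2 0]
```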
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Skye Wanderman-Milne
18420936c4
_scatter_jvp bug fix (#2231) 2020-02-14 18:09:52 -08:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Matthew Johnson
e140520466
make pmap inside of eager scan work, fixes #2018 (#2183)
* make pmap inside of eager scan work, fixes #2018

Co-authored-by: Sharad Vikram <sharadmv@google.com>

* Ensure AxisEnv is instantiated with tuples (#2186)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2020-02-06 17:19:54 -08:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be sloppier about passing numpy arrays and statically
known DeviceArrays as shapes and still hit the jit cache. When the shapes
are not statically known, the error message is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
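A hedged sketch (not from the commit) of the behavior described above: a shape given as a plain tuple or as a statically known numpy array should canonicalize to the same tuple and hence hit the same jit cache entry:
```
import numpy as onp
from jax import random

key = random.PRNGKey(0)
a = random.normal(key, shape=(2, 3))
b = random.normal(key, shape=onp.array([2, 3]))  # canonicalized to (2, 3)
assert a.shape == b.shape == (2, 3)
```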
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure. We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
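As an illustrative sketch (not from the commit), printing a jaxpr after this change shows constants and positional inputs only; there is no separate freevars list:
```
import jax
import jax.numpy as np

def f(x):
  return np.sin(x) * 2.0

print(jax.make_jaxpr(f)(1.0))  # { lambda ; a. ... } -- invars only, no freevars
```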
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Srinivas Vasudevan
62966d9a9f
Add gammainc/gammaincc to JAX (#2064) 2020-01-29 11:25:21 -05:00
Peter Hawkins
0904e5ff74
Fix implementation of cumsum/cumprod for boolean inputs. (#2112)
Check for number inputs in the reduce_window_sum dtype rule.
2020-01-29 10:51:39 -05:00
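An illustrative example of the fixed behavior (assuming NumPy-matching semantics): booleans accumulate as integers rather than saturating at True:
```
import jax.numpy as np

x = np.array([True, False, True, True])
print(np.cumsum(x))  # [1 1 2 3]
```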
Roman Novak
6a4bb95169
Make the reverse operator work on an empty list of dimensions
Example that this fixes:
```
from jax import lax
import jax.numpy as np
from jax.api import jacrev

x = np.ones((3, 5))

def f(x):
  return lax.conv_general_dilated(lhs=x, 
                                  rhs=np.ones((5, 2)), 
                                  window_strides=(), 
                                  padding='VALID', 
                                  dimension_numbers=('NC', 'IO', 'NC'))
  
jacrev(f)(x)
```
currently gives
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-136-2ad65e41f1de> in <module>()
     12                                   dimension_numbers=('NC', 'IO', 'NC'))
     13 
---> 14 jacrev(f)(x).shape

15 frames
google3/third_party/py/jax/api.py in jacfun(*args, **kwargs)
    514     y, pullback = vjp(f_partial, *dyn_args)
    515     holomorphic or tree_map(_check_real_output_jacrev, y)
--> 516     jac = vmap(pullback)(_std_basis(y))
    517     jac = jac[0] if isinstance(argnums, int) else jac
    518     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

google3/third_party/py/jax/api.py in batched_fun(*args)
    692     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    693     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 694                               lambda: _flatten_axes(out_tree(), out_axes))
    695     return tree_unflatten(out_tree(), out_flat)
    696 

google3/third_party/py/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     38 def batch(fun, in_vals, in_dims, out_dim_dests):
     39   size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 40   out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
     41   return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
     42 

google3/third_party/py/jax/interpreters/batching.py in batch_fun(fun, in_vals, in_dims)
     44   with new_master(BatchTrace) as master:
     45     fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 46     out_vals = fun.call_wrapped(*in_vals)
     47     del master
     48   return out_vals, out_dims()

google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     gen = None
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args
    154     while stack:

google3/third_party/py/jax/api.py in _vjp_pullback_wrapper(fun, cotangent_dtypes, io_tree, py_args)
   1237              "match type of corresponding primal output ({})")
   1238       raise TypeError(msg.format(_dtype(a), dtype))
-> 1239   ans = fun(*args)
   1240   return tree_unflatten(out_tree, ans)
   1241 

google3/third_party/py/jax/interpreters/ad.py in vjp_(*cts)
    114     dummy_primals_and_cts = (core.unit,) * len(cts) + cts
    115     dummy_args = (undefined_primal,) * len(jaxpr.invars)
--> 116     _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts)
    117     arg_cts = arg_cts[len(primals):]
    118     return map(instantiate_zeros, primals, arg_cts)

google3/third_party/py/jax/interpreters/ad.py in backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in)
    222       map(write_cotangent, bound_vars, ct_free_vars_out)
    223     else:
--> 224       cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
    225     cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    226     map(write_cotangent, eqn.invars, cts_out)

google3/third_party/py/jax/interpreters/ad.py in bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs)
    505   assert (x is undefined_primal) ^ (y is undefined_primal)
    506   if x is undefined_primal:
--> 507     out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
    508     return out, None
    509   else:

google3/third_party/py/jax/lax/lax.py in _conv_general_dilated_transpose_lhs(g, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, lhs_shape, rhs_shape, precision)
   2042       window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
   2043       rhs_dilation)
-> 2044   revd_weights = rev(rhs, rhs_sdims)
   2045   return conv_general_dilated(
   2046       g, revd_weights, window_strides=lhs_dilation, padding=padding,

google3/third_party/py/jax/lax/lax.py in rev(operand, dimensions)
    671   operator.
    672   """
--> 673   return rev_p.bind(operand, dimensions=tuple(dimensions))
    674 
    675 def select(pred, on_true, on_false):

google3/third_party/py/jax/core.py in bind(self, *args, **kwargs)
    157     top_trace = find_top_trace(args)
    158     if top_trace is None:
--> 159       return self.impl(*args, **kwargs)
    160 
    161     tracers = map(top_trace.full_raise, args)

google3/third_party/py/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    159 def apply_primitive(prim, *args, **params):
    160   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 161   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
    162   return compiled_fun(*args)
    163 

google3/third_party/py/jax/interpreters/xla.py in xla_primitive_callable(prim, *arg_specs, **params)
    167   device = _device_from_arg_devices(arg_devices)
    168   backend = xb.get_device_backend(device)
--> 169   aval_out = prim.abstract_eval(*avals, **params)
    170   if not prim.multiple_results:
    171     handle_result = aval_to_result_handler(device, aval_out)

google3/third_party/py/jax/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
   1540     return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
   1541   elif least_specialized is ShapedArray:
-> 1542     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1543   elif least_specialized is UnshapedArray:
   1544     return UnshapedArray(dtype_rule(*args, **kwargs))

google3/third_party/py/jax/lax/lax.py in _rev_shape_rule(operand, dimensions)
   2620     msg = 'rev dimensions must be unique, got {}.'
   2621     raise TypeError(msg.format(dimensions))
-> 2622   if not _max(dimensions) < operand.ndim:
   2623     msg = ('rev dimensions must all be less than operand ndim, got dimensions '
   2624            '{} for operand ndim {}.')

ValueError: max() arg is an empty sequence
```
2020-01-27 00:16:04 -08:00
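A minimal sketch of the intended behavior after the fix: reversing over an empty tuple of dimensions is a no-op instead of raising on max() of an empty sequence:
```
import jax.numpy as np
from jax import lax

x = np.ones((3, 5))
y = lax.rev(x, dimensions=())
assert y.shape == x.shape
```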
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Peter Hawkins
632326ac5c
Add unsupported wrapper around XLA RngUniform API. (#2068) 2020-01-24 16:58:00 -05:00
Peter Hawkins
a3de80201f
Fix type specifications for bitwise ops. (#2054) 2020-01-23 11:53:55 -05:00
brett koonce
e18d697ac6 minor spelling tweaks (#2043) 2020-01-22 17:18:00 -08:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Julius Kunze
55c971e47f Implement shapecheck for more primitives (#1990)
* shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals

* WIP shapecheck np.pad

* Implement shapecheck of gather, pad

* Fix shapecheck of pad

* Implement polymorphic shape rule for (strided/dilated) convolution, refactor

* Cleanup

* Fix

* Remove all polymorphic shape rules, reuse shape rules instead.

* Register shape_rule for all standard_primitives

* Remove ShapeExpr, canonicalize_poly, renames

* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes

* Allow Poly of form d*poly + k to be divided by d

* Fix bug, inline poly_without_zeros.
2020-01-15 16:36:00 -08:00
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the log of the absolute value of the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
Peter Hawkins
facbe0d76a
Handle 0D convolutions correctly in shape rule. (#1972) 2020-01-09 14:36:37 -05:00
Matthew Johnson
327dca8f76
Merge pull request #1944 from clemisch/master
Implement numpy.gradient
2020-01-09 10:46:57 -08:00
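A short example of the new function (NumPy-compatible semantics assumed): central differences in the interior, one-sided differences at the edges:
```
import jax.numpy as np

y = np.array([1.0, 2.0, 4.0, 7.0])
print(np.gradient(y))  # [1.  1.5 2.5 3. ]
```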
Peter Hawkins
ab2582585e
Implement np.sign for unsigned integers. (#1970)
Fix definition of np.sign for complex numbers.
Document lax.sign better for non-float types.
2020-01-09 11:16:52 -05:00
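An illustrative sketch of the unsigned-integer case: the sign is 0 for zero and 1 for any other value, since unsigned values are never negative:
```
import jax.numpy as np

x = np.array([0, 3, 7], dtype=np.uint32)
print(np.sign(x))  # [0 1 1]
```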
clemisch
c907504078
Merge branch 'master' into master 2020-01-09 07:42:55 +01:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Clemens Schmid
48cb6af6b4 Support None and negative indices in slice_in_dim 2020-01-08 12:22:12 +01:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Matthew Johnson
bb9cd23368 tweak shape error message, add test 2020-01-06 21:38:00 -08:00
Roy Frostig
1ca9e9b251 Concatenate error messages under numpy.{zeros,ones,full}. Closes #1822 2020-01-06 18:20:57 -08:00
Peter Hawkins
574a9ed2cb
Fix incorrect symbolic zero instantiation in scatter JVP rule. (#1903) 2019-12-20 16:09:55 -05:00
Peter Hawkins
178c0d821e
Fix type problem in dynamic_slice_in_dim in int32 default dtype mode. (#1902) 2019-12-20 13:29:53 -05:00
Matthew Johnson
8bd1a46ce7 revise handling of 'backend' values 2019-12-18 14:40:20 -08:00
Peter Hawkins
d692965f88
Implement missing case in scatter batching rule. (#1885)
Add systematic batching tests for gather and scatter-add.
2019-12-17 21:42:37 -05:00
tamaranorman
4af04cefa9 Support dilated transposed convolutions in the conv_transpose op. (#1823)
PiperOrigin-RevId: 284155973
2019-12-16 18:03:17 -08:00
Peter Hawkins
b26a12a358
Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.do… (#1872)
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.

Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.

Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').

Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
2019-12-16 20:48:19 -05:00
Matthew Johnson
fbde09f567 add tuple_args logic to xla primitive application 2019-12-12 05:21:11 -08:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
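A small example of the new function: the next representable float32 after 1.0 in the direction of 2.0:
```
import jax.numpy as np

print(np.nextafter(np.float32(1.0), np.float32(2.0)))  # ~1.0000001
```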
Stephan Hoyer
6ac1c569e8
Use HIGHEST precision for dot_general in linalg JVP rules (#1835) 2019-12-10 00:38:18 -08:00
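A minimal sketch (not the JVP rule itself) of requesting HIGHEST precision for a dot_general, as the linalg JVP rules now do:
```
import jax.numpy as np
from jax import lax

a = np.ones((4, 3))
b = np.ones((3, 2))
c = lax.dot_general(a, b,
                    dimension_numbers=(((1,), (0,)), ((), ())),
                    precision=lax.Precision.HIGHEST)
print(c.shape)  # (4, 2)
```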
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
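An illustrative sketch of the distinction in question (assuming the eventual 32-bit defaults): Python floats become float32 arrays unless 64-bit mode is explicitly enabled:
```
import jax
import jax.numpy as np

print(np.array(1.0).dtype)                 # float32 under 32-bit defaults
jax.config.update("jax_enable_x64", True)  # opt back in to 64-bit types
print(np.array(1.0).dtype)                 # float64
```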
tamaranorman
26e863923a Support atrous convolution with SAME padding, and add a warning when using transposed convolution with SAME or VALID padding. (#1806)
PiperOrigin-RevId: 283517237
2019-12-09 08:06:59 -08:00
Peter Hawkins
f3c8af49e7
Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794)
* Fix bugs in handling of convolutions whose LHS has spatial size 0.

* Use onp.shape to compute shapes.
2019-12-02 14:43:43 -05:00
Matthew Johnson
115d365a92 raise error if we do concrete aval FLOPs w/o remat 2019-11-27 19:52:24 -08:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
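A minimal sketch of the decorator described above (exposed as jax.remat): intermediates of g are recomputed during the backward pass instead of being saved from the forward pass:
```
import jax
import jax.numpy as np

@jax.remat
def g(x):
  return np.sin(np.sin(x))

def loss(x):
  return np.sum(g(x) ** 2)

print(jax.grad(loss)(np.ones(3)))  # gradients computed with g's residuals rematerialized
```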
Peter Hawkins
a8a19e196c
Implement batching rule for lax._select_and_gather_add (#1736) 2019-11-21 11:52:58 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

* implement our own type promotion operators that understand bfloat16 types
* wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
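A small sketch of the new dtype (still immature, per the note above): arrays can be created and computed with in bfloat16 at reduced precision:
```
import jax.numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.bfloat16)
print(x.dtype)        # bfloat16
print((x + x).dtype)  # bfloat16
```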
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
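An example of the promotion behavior described above (under the weak_type rules): a Python scalar no longer widens the array's dtype:
```
import jax.numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)
print((1.0 + x).dtype)  # float32, not float64
```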
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
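A minimal sketch of the renamed module in use (assuming 32-bit default dtypes, under which float64 canonicalizes to float32):
```
import numpy as onp
from jax import dtypes

print(dtypes.canonicalize_dtype(onp.float64))  # float32 under 32-bit defaults
```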
Peter Hawkins
e670bd1a9a
Add stricter type checks for the start_indices arguments to dynamic_slice and dynamic_update_slice. (#1691) 2019-11-14 15:51:27 -05:00
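A hedged usage sketch of the stricter contract: start indices are scalars of an integer dtype, one per dimension:
```
import jax.numpy as np
from jax import lax

x = np.arange(12).reshape(3, 4)
print(lax.dynamic_slice(x, (np.int32(1), np.int32(0)), (2, 2)))  # 2x2 block starting at row 1
```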
Matthew Johnson
6cd995e3ff allow tokens in op-by-op by calling into _xla_callable_args 2019-11-12 18:38:07 -08:00
Matthew Johnson
67a9247ebe avoid staging out some trivial convert_element_types 2019-11-05 16:52:46 -08:00
Stephan Hoyer
89c90923db Add np.fft.ifftn (#1594)
Fixes GH1010
2019-10-30 10:40:02 -07:00
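A short round-trip example of the new function:
```
import jax.numpy as np

x = np.arange(8.0).reshape(2, 4)
print(np.allclose(np.fft.ifftn(np.fft.fftn(x)).real, x))  # True
```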