174 Commits

Author SHA1 Message Date
Clemens Schmid
48cb6af6b4 Support None and negative indices in slice_in_dim 2020-01-08 12:22:12 +01:00
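
A minimal sketch of what this enables, assuming Python-slicing-like handling of None and negative indices in lax.slice_in_dim:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)

# limit_index=None now means "to the end of the axis", and negative
# start/limit indices count from the end, as in Python slicing.
print(lax.slice_in_dim(x, 2, None))   # [2 3 4 5]
print(lax.slice_in_dim(x, -3, -1))    # [3 4]
```
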
Peter Hawkins
574a9ed2cb
Fix incorrect symbolic zero instantiation in scatter JVP rule. (#1903) 2019-12-20 16:09:55 -05:00
Peter Hawkins
178c0d821e
Fix type problem in dynamic_slice_in_dim in int32 default dtype mode. (#1902) 2019-12-20 13:29:53 -05:00
Matthew Johnson
8bd1a46ce7 revise handling of 'backend' values 2019-12-18 14:40:20 -08:00
Peter Hawkins
d692965f88
Implement missing case in scatter batching rule. (#1885)
Add systematic batching tests for gather and scatter-add.
2019-12-17 21:42:37 -05:00
tamaranorman
4af04cefa9 Support dilated transposed convolutions in the conv_transpose op. (#1823)
PiperOrigin-RevId: 284155973
2019-12-16 18:03:17 -08:00
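
A hedged sketch of the dilated case this adds; the rhs_dilation argument and the string dimension_numbers are assumed from lax.conv_transpose's signature:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 8, 3))    # NHWC input
k = jnp.ones((3, 3, 3, 8))    # HWIO kernel

# Transposed convolution with a dilated (atrous) kernel.
y = lax.conv_transpose(x, k, strides=(2, 2), padding='SAME',
                       rhs_dilation=(2, 2),
                       dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(y.shape)
```
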
Peter Hawkins
b26a12a358
Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.do… (#1872)
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.

Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.

Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').

Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
2019-12-16 20:48:19 -05:00
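
A small sketch of the new behavior; the exact result semantics are whatever this PR implements, the point being that bool_ inputs are now accepted directly:

```python
import jax.numpy as jnp

a = jnp.array([True, False, True])
b = jnp.array([True, True, False])

# Boolean inputs no longer need an explicit cast before arithmetic.
print(jnp.add(a, b))
print(jnp.multiply(a, b))
print(jnp.einsum('i,i->', a, b))
```
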
Matthew Johnson
fbde09f567 add tuple_args logic to xla primitive application 2019-12-12 05:21:11 -08:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
Stephan Hoyer
6ac1c569e8
Use HIGHEST precision for dot_general in linalg JVP rules (#1835) 2019-12-10 00:38:18 -08:00
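
For context, a sketch of the precision option referenced here (lax.Precision.HIGHEST requests full-precision accumulation, e.g. float32 instead of bfloat16 matmuls on TPU):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 4), jnp.float32)
b = jnp.ones((4, 4), jnp.float32)

# A plain matmul via dot_general, pinned to the highest available precision.
c = lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ())),
                    precision=lax.Precision.HIGHEST)
```
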
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
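
A quick sketch of how to see which defaults are in effect; the jax_enable_x64 flag is the existing switch between the 32-bit and 64-bit modes:

```python
import jax.numpy as jnp

# In 32-bit default mode these print int32 / float32; with the
# jax_enable_x64 flag (or JAX_ENABLE_X64=1) they print int64 / float64.
print(jnp.array(1).dtype)
print(jnp.array(1.0).dtype)
```
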
tamaranorman
26e863923a Support atrous conv in same-padded convolution and add a warning if using transposed convolution with same or valid padding. (#1806)
PiperOrigin-RevId: 283517237
2019-12-09 08:06:59 -08:00
Peter Hawkins
f3c8af49e7
Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794)
* Fix bugs in handling of convolutions whose LHS has spatial size 0.

* Use onp.shape to compute shapes.
2019-12-02 14:43:43 -05:00
Matthew Johnson
115d365a92 raise error if we do concrete aval FLOPs w/o remat 2019-11-27 19:52:24 -08:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
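
A minimal sketch of the decorator described above (jax.remat, added in PR #1749); the function body is just illustrative:

```python
import jax
import jax.numpy as jnp

@jax.remat   # don't save this block's intermediates; recompute them instead
def block(x):
    return jnp.sin(jnp.sin(x))

# Reverse-mode autodiff through `block` replays the forward computation
# during the backward pass rather than storing the inner sin's output.
print(jax.grad(lambda x: block(x))(1.0))
```
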
Peter Hawkins
a8a19e196c
Implement batching rule for lax._select_and_gather_add (#1736) 2019-11-21 11:52:58 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

* implement our own type promotion operators that understand bfloat16 types
* wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
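
A small sketch of the new dtype in use (jnp.bfloat16 is the type this PR wires up):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
y = x * 2

print(y.dtype)   # bfloat16
print(y)         # computed at bfloat16's reduced precision
```
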
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
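
A short sketch of the promotion behavior described above:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0], dtype=jnp.float32)

# The Python scalar 1.0 is "weakly typed": it no longer drags the result up
# to the 64-bit default, so the array's float32 is preserved.
print((1.0 + x).dtype)   # float32
```
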
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions directly; use the wrappers in jax.types instead. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
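
A brief sketch of one helper that now lives in this module (canonicalize_dtype, moved here from xla_bridge per the message above):

```python
import numpy as onp
from jax import dtypes

# Maps a requested dtype onto the dtype JAX will actually use; in 32-bit
# default mode, 64-bit types are canonicalized down to 32-bit ones.
print(dtypes.canonicalize_dtype(onp.float64))
```
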
Peter Hawkins
e670bd1a9a
Add stricter type checks for the start_indices arguments to dynamic_slice and dynamic_update_slice. (#1691) 2019-11-14 15:51:27 -05:00
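
A small sketch of the arguments these checks apply to, assuming the usual lax.dynamic_slice signature:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

# start_indices must be one scalar integer per dimension; the stricter checks
# reject float or wrongly-shaped index arguments up front.
print(lax.dynamic_slice(x, (1, 2), (2, 2)))   # [[ 6  7] [10 11]]
```
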
Matthew Johnson
6cd995e3ff allow tokens in op-by-op by calling into _xla_callable_args 2019-11-12 18:38:07 -08:00
Matthew Johnson
67a9247ebe avoid staging out some trivial convert_element_types 2019-11-05 16:52:46 -08:00
Stephan Hoyer
89c90923db Add np.fft.ifftn (#1594)
Fixes GH1010
2019-10-30 10:40:02 -07:00
chenyee
6839f28c6a Fix issue #1576 2019-10-28 22:37:01 +08:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use a decorator for the upcast-to-fp32-for-computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Matthew Johnson
0601b8cdc7 make lax.broadcast_in_dim work on scalars
fixes #1548
2019-10-21 15:12:22 -07:00
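
A tiny sketch of the case this fixes: a scalar (0-d) operand with empty broadcast_dimensions:

```python
import jax.numpy as jnp
from jax import lax

# Broadcast a 0-d operand to shape (2, 3); broadcast_dimensions is empty
# because the operand has no axes to map into the result.
y = lax.broadcast_in_dim(jnp.float32(1.0), (2, 3), ())
print(y.shape)   # (2, 3)
```
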
Peter Hawkins
abe6990964
Add some @jit decorators to non-primitive lax functions. (#1542)
Fix the tests so they don't refer to op.__name__, which no longer has a usable value if the function has been jitted.
2019-10-21 10:56:54 -04:00
Peter Hawkins
9c23a95e6a
Add i0e and i1e Bessel functions. (#1541) 2019-10-21 10:30:55 -04:00
Peter Hawkins
06178d298c
Move lax.tie_in inside lax.full_like onto the fill value instead of the output of lax.full. (#1507)
Fixes a bug where constants associated with relu gradients were being hoisted out of loops and materialized, causing a fairly large performance penalty (~20%) for a Resnet-50 model in a loop using infeed.
2019-10-15 15:01:52 -04:00
Peter Hawkins
4a075be62a
Merge pull request #1478 from hawkinsp/infeed
Add experimental support for XLA infeed/outfeed.
2019-10-09 21:09:16 -04:00
James Bradbury
9d2f25cf1a add test 2019-10-09 17:02:11 -07:00
James Bradbury
fb433fb9d2 preserve precision config in dot_general transpose 2019-10-09 16:25:37 -07:00
Peter Hawkins
b8a5473614 Add experimental support for XLA infeed/outfeed. 2019-10-09 15:05:54 -04:00
James Bradbury
6d29c4e352 remove dot primitive in favor of dot_general 2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3 add dot_general masking rules 2019-10-08 14:44:10 -07:00
James Bradbury
658882513e avoid more transposes in dot_general batch rule 2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92 Ensure lax.scatter cache hits in op-by-op mode 2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe Use square(x) instead of pow(x, 2) in div JVP. 2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887 better abs jvp 2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580 python2 fix for ShapeExpr slicing 2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275 tweaks to simplify masked jaxprs, rnn test 2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0 start writing nesting test 2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b add a 'monomorphic dim' symbol, bug fixes 2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb fix broadcasting bug in rem jvp, fixes #1350 2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b remove broadcasts from _dot_general_batch_rule 2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf Tiny change to enable vmap with dimension numbers. 2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d add primitive for rsqrt 2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d fix lax._canonicalize_shape for ShapeExprs 2019-09-03 17:18:23 -07:00