159 Commits

Author SHA1 Message Date
Peter Hawkins
a8a19e196c
Implement batching rule for lax._select_and_gather_add (#1736) 2019-11-21 11:52:58 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
e670bd1a9a
Add stricter type checks for the start_indices arguments to dynamic_slice and dynamic_update_slice. (#1691) 2019-11-14 15:51:27 -05:00
Matthew Johnson
6cd995e3ff allow tokens in op-by-op by calling into _xla_callable_args 2019-11-12 18:38:07 -08:00
Matthew Johnson
67a9247ebe avoid staging out some trivial convert_element_types 2019-11-05 16:52:46 -08:00
Stephan Hoyer
89c90923db Add np.fft.ifftn (#1594)
Fixes GH1010
2019-10-30 10:40:02 -07:00
chenyee
6839f28c6a Fix issue #1576 2019-10-28 22:37:01 +08:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use decorator for the upcast to fp32 for computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Matthew Johnson
0601b8cdc7 make lax.broadcast_in_dim work on scalars
fixes #1548
2019-10-21 15:12:22 -07:00
Peter Hawkins
abe6990964
Add some @jit decorators to non-primitive lax functions. (#1542)
Fix the tests so they don't refer to op.__name__, which no longer has a usable value if the function has been jitted.
2019-10-21 10:56:54 -04:00
Peter Hawkins
9c23a95e6a
Add i0e and i1e Bessel functions. (#1541) 2019-10-21 10:30:55 -04:00
Peter Hawkins
06178d298c
Move lax.tie_in inside lax.full_like onto the fill value instead the output of lax.full. (#1507)
Fixes a bug where constants associated with relu gradients were being hoisted out of loops and materialized, causing a fairly large performance penalty (~20%) for a Resnet-50 model in a loop using infeed.
2019-10-15 15:01:52 -04:00
Peter Hawkins
4a075be62a
Merge pull request #1478 from hawkinsp/infeed
Add experimental support for XLA infeed/outfeed.
2019-10-09 21:09:16 -04:00
James Bradbury
9d2f25cf1a add test 2019-10-09 17:02:11 -07:00
James Bradbury
fb433fb9d2 preserve precision config in dot_general transpose 2019-10-09 16:25:37 -07:00
Peter Hawkins
b8a5473614 Add experimental support for XLA infeed/outfeed. 2019-10-09 15:05:54 -04:00
James Bradbury
6d29c4e352 remove dot primitive in favor of dot_general 2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3 add dot_general masking rules 2019-10-08 14:44:10 -07:00
James Bradbury
658882513e avoid more transposes in dot_general batch rule 2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92 Ensure lax.scatter cache hits in op-by-op mode 2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe Use square(x) instead of pow(x, 2) in div JVP. 2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887 better abs jvp 2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580 python2 fix for ShapeExpr slicing 2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275 tweaks to simplify masked jaxprs, rnn test 2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0 start writing nesting test 2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b add a 'monomorphic dim' symbol, bug fixes 2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb fix broadcasting bug in rem jvp, fixes #1350 2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b remove broadcasts from _dot_general_batch_rule 2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf Tiny change to enable vmap with dimension numbers. 2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d add primitive for rsqrt 2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d fix lax._canonicalize_shape for ShapeExprs 2019-09-03 17:18:23 -07:00
Matthew Johnson
772fdb8c4e move automasking prototype into jax/interpreters
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
fbc85af54f made polymorphic jaxprs, reshape fail 2019-09-03 17:10:17 -07:00
Matthew Johnson
e254dc43ab wip
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944 avoid Calls inside While/Cond
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Matthew Johnson
0cc21c8d72
Merge branch 'master' into multibackend 2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5 Make op-by-op work with all jit-returned devicearrays. 2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013 WIP: experimental multibackend jit 2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00