151 Commits

Author SHA1 Message Date
chenyee
6839f28c6a Fix issue #1576 2019-10-28 22:37:01 +08:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use decorator for the upcast to fp32 for computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Matthew Johnson
0601b8cdc7 make lax.broadcast_in_dim work on scalars
fixes #1548
2019-10-21 15:12:22 -07:00
Peter Hawkins
abe6990964
Add some @jit decorators to non-primitive lax functions. (#1542)
Fix the tests so they don't refer to op.__name__, which no longer has a usable value if the function has been jitted.
2019-10-21 10:56:54 -04:00
Peter Hawkins
9c23a95e6a
Add i0e and i1e Bessel functions. (#1541) 2019-10-21 10:30:55 -04:00
Peter Hawkins
06178d298c
Move lax.tie_in inside lax.full_like onto the fill value instead the output of lax.full. (#1507)
Fixes a bug where constants associated with relu gradients were being hoisted out of loops and materialized, causing a fairly large performance penalty (~20%) for a Resnet-50 model in a loop using infeed.
2019-10-15 15:01:52 -04:00
Peter Hawkins
4a075be62a
Merge pull request #1478 from hawkinsp/infeed
Add experimental support for XLA infeed/outfeed.
2019-10-09 21:09:16 -04:00
James Bradbury
9d2f25cf1a add test 2019-10-09 17:02:11 -07:00
James Bradbury
fb433fb9d2 preserve precision config in dot_general transpose 2019-10-09 16:25:37 -07:00
Peter Hawkins
b8a5473614 Add experimental support for XLA infeed/outfeed. 2019-10-09 15:05:54 -04:00
James Bradbury
6d29c4e352 remove dot primitive in favor of dot_general 2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3 add dot_general masking rules 2019-10-08 14:44:10 -07:00
James Bradbury
658882513e avoid more transposes in dot_general batch rule 2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92 Ensure lax.scatter cache hits in op-by-op mode 2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe Use square(x) instead of pow(x, 2) in div JVP. 2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887 better abs jvp 2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580 python2 fix for ShapeExpr slicing 2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275 tweaks to simplify masked jaxprs, rnn test 2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0 start writing nesting test 2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b add a 'monomorphic dim' symbol, bug fixes 2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb fix broadcasting bug in rem jvp, fixes #1350 2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b remove broadcasts from _dot_general_batch_rule 2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf Tiny change to enable vmap with dimension numbers. 2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d add primitive for rsqrt 2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d fix lax._canonicalize_shape for ShapeExprs 2019-09-03 17:18:23 -07:00
Matthew Johnson
772fdb8c4e move automasking prototype into jax/interpreters
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
fbc85af54f made polymorphic jaxprs, reshape fail 2019-09-03 17:10:17 -07:00
Matthew Johnson
e254dc43ab wip
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944 avoid Calls inside While/Cond
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Matthew Johnson
0cc21c8d72
Merge branch 'master' into multibackend 2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5 Make op-by-op work with all jit-returned devicearrays. 2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013 WIP: experimental multibackend jit 2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
Peter Hawkins
932877dde6 Remove unnecessary reshape/concatenate in dynamic_slice_in_dim. 2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0 Fix Python 2 compatibility. 2019-08-15 13:14:41 -04:00
Peter Hawkins
e28e73b38f Address review comment. 2019-08-15 12:33:36 -04:00
Peter Hawkins
e57a5c42c5 Fix batching rule. 2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741 Fix batching rule. 2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
3e78a0e290 Keep ShapedArray avals on xla.DeviceArray values
Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
2019-08-12 10:03:04 -04:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00