chenyee
6839f28c6a
Fix issue #1576
2019-10-28 22:37:01 +08:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. ( #1553 )
...
* Add tests for float16 support in lax_test.py.
Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.
* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.
* Add float16 testing to lax_numpy_test.py as well.
* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.
* Relax some test tolerances further.
* Further relax test tolerances.
* Another tolerance relaxation.
* Use decorator for the upcast to fp32 for computation pattern.
Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Matthew Johnson
0601b8cdc7
make lax.broadcast_in_dim work on scalars
...
fixes #1548
2019-10-21 15:12:22 -07:00
Peter Hawkins
abe6990964
Add some @jit decorators to non-primitive lax functions. ( #1542 )
...
Fix the tests so they don't refer to op.__name__, which no longer has a usable value if the function has been jitted.
2019-10-21 10:56:54 -04:00
Peter Hawkins
9c23a95e6a
Add i0e and i1e Bessel functions. ( #1541 )
2019-10-21 10:30:55 -04:00
Peter Hawkins
06178d298c
Move lax.tie_in inside lax.full_like onto the fill value instead the output of lax.full. ( #1507 )
...
Fixes a bug where constants associated with relu gradients were being hoisted out of loops and materialized, causing a fairly large performance penalty (~20%) for a Resnet-50 model in a loop using infeed.
2019-10-15 15:01:52 -04:00
Peter Hawkins
4a075be62a
Merge pull request #1478 from hawkinsp/infeed
...
Add experimental support for XLA infeed/outfeed.
2019-10-09 21:09:16 -04:00
James Bradbury
9d2f25cf1a
add test
2019-10-09 17:02:11 -07:00
James Bradbury
fb433fb9d2
preserve precision config in dot_general transpose
2019-10-09 16:25:37 -07:00
Peter Hawkins
b8a5473614
Add experimental support for XLA infeed/outfeed.
2019-10-09 15:05:54 -04:00
James Bradbury
6d29c4e352
remove dot primitive in favor of dot_general
2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3
add dot_general masking rules
2019-10-08 14:44:10 -07:00
James Bradbury
658882513e
avoid more transposes in dot_general batch rule
2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
...
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92
Ensure lax.scatter cache hits in op-by-op mode
2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe
Use square(x) instead of pow(x, 2) in div JVP.
2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887
better abs jvp
2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580
python2 fix for ShapeExpr slicing
2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275
tweaks to simplify masked jaxprs, rnn test
2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0
start writing nesting test
2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b
add a 'monomorphic dim' symbol, bug fixes
2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb
fix broadcasting bug in rem jvp, fixes #1350
2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
...
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b
remove broadcasts from _dot_general_batch_rule
2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf
Tiny change to enable vmap with dimension numbers.
2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d
add primitive for rsqrt
2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d
fix lax._canonicalize_shape for ShapeExprs
2019-09-03 17:18:23 -07:00
Matthew Johnson
772fdb8c4e
move automasking prototype into jax/interpreters
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
fbc85af54f
made polymorphic jaxprs, reshape fail
2019-09-03 17:10:17 -07:00
Matthew Johnson
e254dc43ab
wip
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
cac042c34a
move asinh/acosh/atanh to lax_numpy.py only
2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944
avoid Calls inside While/Cond
...
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e
Add jax.devices() and friends, and add devices
arg to pmap.
...
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Matthew Johnson
0cc21c8d72
Merge branch 'master' into multibackend
2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737
add dtype warnings to array-creation routines
...
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e
resolve merge conflicts with master
2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47
Merge branch 'master' into multibackend
2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e
De-tuplify the rest of the core
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56
Start exploring jaxprs without tuples
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5
Make op-by-op work with all jit-returned devicearrays.
2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013
WIP: experimental multibackend jit
2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884
Use select
instead of rem
to handle index wraparound.
2019-08-15 16:41:05 -04:00
Peter Hawkins
932877dde6
Remove unnecessary reshape/concatenate in dynamic_slice_in_dim.
2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0
Fix Python 2 compatibility.
2019-08-15 13:14:41 -04:00
Peter Hawkins
e28e73b38f
Address review comment.
2019-08-15 12:33:36 -04:00
Peter Hawkins
e57a5c42c5
Fix batching rule.
2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741
Fix batching rule.
2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c
Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
...
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
3e78a0e290
Keep ShapedArray avals on xla.DeviceArray values
...
Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
2019-08-12 10:03:04 -04:00
Peter Hawkins
a8ddf071bd
Add test case for concurrent device_get and device_put calls.
...
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00