141 Commits

Author SHA1 Message Date
James Bradbury
6d29c4e352 remove dot primitive in favor of dot_general 2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3 add dot_general masking rules 2019-10-08 14:44:10 -07:00
James Bradbury
658882513e avoid more transposes in dot_general batch rule 2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92 Ensure lax.scatter cache hits in op-by-op mode 2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe Use square(x) instead of pow(x, 2) in div JVP. 2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887 better abs jvp 2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580 python2 fix for ShapeExpr slicing 2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275 tweaks to simplify masked jaxprs, rnn test 2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0 start writing nesting test 2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b add a 'monomorphic dim' symbol, bug fixes 2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb fix broadcasting bug in rem jvp, fixes #1350 2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b remove broadcasts from _dot_general_batch_rule 2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf Tiny change to enable vmap with dimension numbers. 2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d add primitive for rsqrt 2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d fix lax._canonicalize_shape for ShapeExprs 2019-09-03 17:18:23 -07:00
Matthew Johnson
772fdb8c4e move automasking prototype into jax/interpreters
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
fbc85af54f made polymorphic jaxprs, reshape fail 2019-09-03 17:10:17 -07:00
Matthew Johnson
e254dc43ab wip
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944 avoid Calls inside While/Cond
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e Add jax.devices() and friends, and add devices arg to pmap.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
    will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
    should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Matthew Johnson
0cc21c8d72
Merge branch 'master' into multibackend 2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e resolve merge conflicts with master 2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47 Merge branch 'master' into multibackend 2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5 Make op-by-op work with all jit-returned devicearrays. 2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013 WIP: experimental multibackend jit 2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
Peter Hawkins
932877dde6 Remove unnecessary reshape/concatenate in dynamic_slice_in_dim. 2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0 Fix Python 2 compatibility. 2019-08-15 13:14:41 -04:00
Peter Hawkins
e28e73b38f Address review comment. 2019-08-15 12:33:36 -04:00
Peter Hawkins
e57a5c42c5 Fix batching rule. 2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741 Fix batching rule. 2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
3e78a0e290 Keep ShapedArray avals on xla.DeviceArray values
Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
2019-08-12 10:03:04 -04:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
Peter Hawkins
51eb67f755 pmap optimization: Don't precompute size and ndim on DeviceArrays.
We don't even look at them most of the time, and they are in the critical path for running jit/pmap code.

Saves ~1-2ms on a pmap microbenchmark.
2019-08-06 10:29:25 -04:00
James Bradbury
a26963fe87
Merge pull request #1106 from google/jb/bool-reduction
fix jax.numpy reduction init_val for bools
2019-08-05 10:45:17 -07:00
Peter Hawkins
0ef05d7586 Cleanups to xla_bridge.py
Remove stringification of dtypes. The NumPy dtype handling bug has to do with types with different hashes comparing as equal. This only does not happen between two np.dtype objects; it is sufficient to ismply ensure we actually have an np.dtype rather than something dtype-like (e.g., a string or NumPy type object).
Remove xla_bridge.infeed_put, which is unused.
Remove xla_bridge.Shape (use xla_client.Shape instead).
Remove xla_bridge.dtype_to_etype_exact (use xla_client.dtype_to_etype instead).
Remove xla_bridge.device_put (inlined the definition into its callers)
Remove xla_bridge.make_tuple (inlined the definition into its callers).
2019-08-04 12:52:39 -04:00
James Bradbury
d0c9f45349 fix jax.numpy reduction init_val for bools 2019-08-03 21:27:06 -07:00
fehiepsi
1b490fb5e0 Merge remote-tracking branch 'upstream/master' into sort 2019-08-01 12:39:53 -04:00
fehiepsi
e1ee87b559 add batching rule for lax.sort 2019-08-01 12:39:33 -04:00
Matthew Johnson
0600b738f4 fix symbolic zero handling in _pad_transpose
tested manually against example from @matthewdhoffman
2019-07-31 13:27:19 -07:00
Peter Hawkins
d0644d6a3a Remove old xla_data_pb2 compatibility shim. 2019-07-29 15:21:47 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
Peter Hawkins
2369d1fe61 Increase minimum Jaxlib version to 0.1.22.
Remove code that preserves backward compatibility with older jaxlib versions.
2019-07-23 21:45:41 -04:00