James Bradbury
6d29c4e352
remove dot primitive in favor of dot_general
2019-10-08 14:44:10 -07:00
James Bradbury
096a52a3a3
add dot_general masking rules
2019-10-08 14:44:10 -07:00
James Bradbury
658882513e
avoid more transposes in dot_general batch rule
2019-10-08 14:44:02 -07:00
James Bradbury
064014b53c
Merge pull request #1374 from google/jb/abs-jvp
...
Improve numerics of abs jvp (and softplus)
2019-09-28 21:43:25 -04:00
Jamie Townsend
f9b9146a92
Ensure lax.scatter cache hits in op-by-op mode
2019-09-24 19:20:12 +02:00
Peter Hawkins
92c42ea1fe
Use square(x) instead of pow(x, 2) in div JVP.
2019-09-23 12:46:15 -04:00
James Bradbury
b39179c887
better abs jvp
2019-09-18 23:55:31 -07:00
Matthew Johnson
99b9e48580
python2 fix for ShapeExpr slicing
2019-09-16 16:30:42 -07:00
Matthew Johnson
6662da8275
tweaks to simplify masked jaxprs, rnn test
2019-09-16 15:47:43 -07:00
Matthew Johnson
b71181d3c0
start writing nesting test
2019-09-15 11:10:05 -07:00
Matthew Johnson
283299649b
add a 'monomorphic dim' symbol, bug fixes
2019-09-15 11:10:05 -07:00
Matthew Johnson
5b6b72c2fb
fix broadcasting bug in rem jvp, fixes #1350
2019-09-15 08:45:58 -07:00
James Bradbury
705eb1cbcb
Merge pull request #1331 from google/jb/dot-general-batch
...
Remove explicit broadcasts in vmap(dot_general)
2019-09-10 14:49:17 -07:00
James Bradbury
b4b14b7e2b
remove broadcasts from _dot_general_batch_rule
2019-09-10 13:58:23 -07:00
Sam Schoenholz
6f2d22fddf
Tiny change to enable vmap with dimension numbers.
2019-09-08 14:19:10 -07:00
James Bradbury
35b63c740d
add primitive for rsqrt
2019-09-04 15:06:46 -07:00
Matthew Johnson
96b8bb2d4d
fix lax._canonicalize_shape for ShapeExprs
2019-09-03 17:18:23 -07:00
Matthew Johnson
772fdb8c4e
move automasking prototype into jax/interpreters
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
fbc85af54f
made polymorphic jaxprs, reshape fail
2019-09-03 17:10:17 -07:00
Matthew Johnson
e254dc43ab
wip
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-09-03 17:10:17 -07:00
Matthew Johnson
cac042c34a
move asinh/acosh/atanh to lax_numpy.py only
2019-08-31 22:39:51 -07:00
Matthew Johnson
478832c944
avoid Calls inside While/Cond
...
fixes #1267
2019-08-31 07:35:37 -07:00
Skye Wanderman-Milne
ae835b747e
Add jax.devices() and friends, and add devices
arg to pmap.
...
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
2019-08-26 11:46:45 -07:00
Matthew Johnson
0cc21c8d72
Merge branch 'master' into multibackend
2019-08-25 13:30:21 -07:00
Matthew Johnson
e90457d737
add dtype warnings to array-creation routines
...
fixes #1230
2019-08-24 08:19:05 -07:00
Anselm Levskaya
685ca6765e
resolve merge conflicts with master
2019-08-22 19:56:27 -07:00
Anselm Levskaya
10e0842f47
Merge branch 'master' into multibackend
2019-08-22 19:52:29 -07:00
Matthew Johnson
b702f8de3e
De-tuplify the rest of the core
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56
Start exploring jaxprs without tuples
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Anselm Levskaya
f01fc35ce5
Make op-by-op work with all jit-returned devicearrays.
2019-08-21 00:22:53 -07:00
Anselm Levskaya
cc87fb6013
WIP: experimental multibackend jit
2019-08-19 23:45:36 -07:00
Peter Hawkins
6d357fe884
Use select
instead of rem
to handle index wraparound.
2019-08-15 16:41:05 -04:00
Peter Hawkins
932877dde6
Remove unnecessary reshape/concatenate in dynamic_slice_in_dim.
2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0
Fix Python 2 compatibility.
2019-08-15 13:14:41 -04:00
Peter Hawkins
e28e73b38f
Address review comment.
2019-08-15 12:33:36 -04:00
Peter Hawkins
e57a5c42c5
Fix batching rule.
2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741
Fix batching rule.
2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c
Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
...
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
3e78a0e290
Keep ShapedArray avals on xla.DeviceArray values
...
Makes abstractification of DeviceArray values cheaper, which is on the critical path for executing a compiled function.
2019-08-12 10:03:04 -04:00
Peter Hawkins
a8ddf071bd
Add test case for concurrent device_get and device_put calls.
...
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
Peter Hawkins
51eb67f755
pmap optimization: Don't precompute size and ndim on DeviceArrays.
...
We don't even look at them most of the time, and they are in the critical path for running jit/pmap code.
Saves ~1-2ms on a pmap microbenchmark.
2019-08-06 10:29:25 -04:00
James Bradbury
a26963fe87
Merge pull request #1106 from google/jb/bool-reduction
...
fix jax.numpy reduction init_val for bools
2019-08-05 10:45:17 -07:00
Peter Hawkins
0ef05d7586
Cleanups to xla_bridge.py
...
Remove stringification of dtypes. The NumPy dtype handling bug has to do with types with different hashes comparing as equal. This only does not happen between two np.dtype objects; it is sufficient to ismply ensure we actually have an np.dtype rather than something dtype-like (e.g., a string or NumPy type object).
Remove xla_bridge.infeed_put, which is unused.
Remove xla_bridge.Shape (use xla_client.Shape instead).
Remove xla_bridge.dtype_to_etype_exact (use xla_client.dtype_to_etype instead).
Remove xla_bridge.device_put (inlined the definition into its callers)
Remove xla_bridge.make_tuple (inlined the definition into its callers).
2019-08-04 12:52:39 -04:00
James Bradbury
d0c9f45349
fix jax.numpy reduction init_val for bools
2019-08-03 21:27:06 -07:00
fehiepsi
1b490fb5e0
Merge remote-tracking branch 'upstream/master' into sort
2019-08-01 12:39:53 -04:00
fehiepsi
e1ee87b559
add batching rule for lax.sort
2019-08-01 12:39:33 -04:00
Matthew Johnson
0600b738f4
fix symbolic zero handling in _pad_transpose
...
tested manually against example from @matthewdhoffman
2019-07-31 13:27:19 -07:00
Peter Hawkins
d0644d6a3a
Remove old xla_data_pb2 compatibility shim.
2019-07-29 15:21:47 -04:00
Peter Hawkins
510a9167c5
Add C++ implementation of pytree logic.
...
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
Peter Hawkins
2369d1fe61
Increase minimum Jaxlib version to 0.1.22.
...
Remove code that preserves backward compatibility with older jaxlib versions.
2019-07-23 21:45:41 -04:00