George Necula
20be478a6e
[host_callback] Add support for pmap and for passing the device to tap
...
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
`id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap
Issue: #5134
Fixes : #5169
2020-12-15 10:46:23 +02:00
jax authors
0ca612c552
Merge pull request #5189 from hawkinsp:herbie
...
PiperOrigin-RevId: 347476608
2020-12-14 15:04:05 -08:00
jax authors
931f925979
Merge pull request #5177 from minoring:pad-reflect-type
...
PiperOrigin-RevId: 347446703
2020-12-14 12:43:44 -08:00
Peter Hawkins
de8df3a86f
Improve a few JVP rules with rewrites from Herbie.
2020-12-14 11:48:16 -05:00
jax authors
a28436bc5f
Merge pull request #5163 from bchetioui:factor_test_utils
...
PiperOrigin-RevId: 347392991
2020-12-14 08:41:46 -08:00
jax authors
7ac3c04a00
Merge pull request #5137 from hawkinsp:dot
...
PiperOrigin-RevId: 347387515
2020-12-14 08:09:09 -08:00
Benjamin Chetioui
3a06dfec61
Set collect_limitations variable to True and remove original_impl.
2020-12-13 15:52:30 +01:00
minoring
c1d86739d1
Remove bool from odd reflect_type test
2020-12-13 09:21:05 +09:00
jax authors
91a0dd7163
Merge pull request #5121 from qiuminxu:add_tflite_example
...
PiperOrigin-RevId: 347177932
2020-12-12 09:53:03 -08:00
jax authors
b85e05956f
Merge pull request #5161 from qiuminxu:jax2tf_avg_pool
...
PiperOrigin-RevId: 347167053
2020-12-12 07:10:17 -08:00
minoring
f0a248c84b
Add reflect_type argument for symmetric and reflect modes in padding
2020-12-12 12:26:47 +09:00
jax authors
4d3e0fd892
Merge pull request #5160 from jakevdp:weak-types-slicing
...
PiperOrigin-RevId: 347094147
2020-12-11 16:08:52 -08:00
Skye Wanderman-Milne
e802bf9b63
Merge pull request #5173 from jakevdp:conv-elem-type-cleanup
...
PiperOrigin-RevId: 347085811
2020-12-11 15:27:40 -08:00
Jake VanderPlas
ebaa35e2b7
Preserve weak types in array slices
2020-12-11 15:06:01 -08:00
Jake VanderPlas
de96856be4
Cleanup: remove unnecessary code in convert_element_type
2020-12-11 14:14:05 -08:00
Qiumin Xu
542f33edf3
Update mnist.py
2020-12-11 12:40:03 -08:00
Adam Paszke
146c6eb308
Check for non-tuples and not int subclasses in Chunked
...
Apparently some projects like to pass in instances of `numpy.int64`
where `int`s are expected, and those fail the subclass check. This
should be a hotfix for them, though it would be good to figure out where
does the NumPy scalar come from, and make them well typed.
2020-12-11 19:07:34 +00:00
Qiumin Xu
aee42a33b1
Added a non-XLA conversion path for reduce_window_sum
2020-12-11 10:13:03 -08:00
jax authors
ca468940b4
Merge pull request #5099 from apaszke:xmap-multiple-mesh-dims
...
PiperOrigin-RevId: 347013166
2020-12-11 09:24:51 -08:00
jax authors
7ee2710d07
Merge pull request #5157 from jakevdp:weak-types-creation
...
PiperOrigin-RevId: 346996341
2020-12-11 07:44:22 -08:00
Adam Paszke
8fcacd645c
Support mapping a single logical axis to multiple mesh axes in xmap
2020-12-11 14:35:31 +00:00
Benjamin Chetioui
b6114a011f
[jax2tf] Factorize and clean up test utils.
2020-12-11 14:27:17 +01:00
Adam Paszke
8438ffe998
Make vectorization the default xmap schedule
...
It's been really annoying to have to spell out all the `vectorize`
schedule components, especially when dealing with lots of axes. And the
utility of checking that an axis is _not_ vectorized seems quite
limited.
Finally, one could argue that vectorization really is _the default_.
`xmap` can be seen as a generalization of `einsum`, or as a way to
expose programming with named axes instead of positional axes. In both
those cases, the focus is on shifting how the program is expressed,
without saying anything about the lowering or execution strategy. But,
vectorization is the default of the whole ecosystem, given the emphasis
most libraries put on broadcasting semantics. So it also makes sense to
adopt it here.
2020-12-11 11:26:21 +00:00
Adam Paszke
136f9eca6e
Make xmap in/out_axes take in native Python containers
...
instead of forcing the pytree leaves to be the custom container exported
by xmap. This makes the API a bit less verbose, and also relaxes it so
that the mapping can be both specified through a dict (mapping
positional axes to named axes), as well as a list.
All this thanks to the recent pytree changes that let us terminate the
flattening before reaching true leaves.
2020-12-11 11:26:21 +00:00
Adam Paszke
f3bfdf8968
Expose is_leaf
predicate for pytree.flatten
...
and add tests for it. The change has already been landed in the TF code,
where the C++ pytree components live. This is why I needed to bump the
commit.
2020-12-11 11:26:18 +00:00
Qiumin Xu
41d8cc729e
Add a tflite mnist example for jax2tf.
2020-12-10 21:19:16 -08:00
Jake VanderPlas
99c69ec9c2
Propagate weak types in jnp.array, jnp.full, jnp.full_like, jnp.zeros_like, jnp.ones_like
2020-12-10 17:16:48 -08:00
Jake VanderPlas
33c2dc1153
Propagate weak types in jnp.array()
2020-12-10 17:16:48 -08:00
Jake VanderPlas
7b097340bf
Fix lax.convert_element_type() with dtype=None
2020-12-10 14:14:36 -08:00
jax authors
245c6d2339
Merge pull request #4850 from jakevdp:weak-types-propagation
...
PiperOrigin-RevId: 346850541
2020-12-10 13:16:18 -08:00
Jake VanderPlas
eb9adcc92a
jnp.arccosh: use same complex branch as numpy
2020-12-10 11:50:14 -08:00
Jake VanderPlas
c63097bc90
Add weak_type argument to convert_element_type_p
2020-12-10 11:10:21 -08:00
jax authors
24a27e07f5
Merge pull request #5150 from jakevdp:allow-singular
...
PiperOrigin-RevId: 346814700
2020-12-10 10:32:38 -08:00
jax authors
5520a28a8a
Merge pull request #5132 from gnecula:tf_colab
...
PiperOrigin-RevId: 346759909
2020-12-10 04:55:32 -08:00
Jake VanderPlas
8a8a48a926
multivariate_normal.logpdf: add (unimplemented) allow_singular argument
2020-12-09 11:39:06 -08:00
jax authors
2f4f940347
Merge pull request #5110 from minoring:linear-ramp-pad
...
PiperOrigin-RevId: 346592848
2020-12-09 11:13:15 -08:00
Jake VanderPlas
f74235cdae
X32 tests: fail on dtype warnings
2020-12-08 13:03:30 -08:00
Peter Hawkins
450747cf72
Only use sum-of-products lowering of integer/bool dots on CPU.
...
Other XLA backends support integer dots.
2020-12-08 12:31:01 -05:00
George Necula
5f42929357
[jax2tf] Deprecate the getting_started colab
2020-12-08 13:37:54 +02:00
jax authors
fa15380d97
Merge pull request #5130 from hawkinsp:trig
...
PiperOrigin-RevId: 346277512
2020-12-08 01:59:36 -08:00
Peter Hawkins
706d78553c
Upgrade trigonometric functions to primitives.
...
This allows for slightly simpler derivatives than differentiating their implementations.
2020-12-07 17:34:27 -05:00
jax authors
b1be39499e
Merge pull request #5085 from jakevdp:conv-elem-type-devicearray
...
PiperOrigin-RevId: 346163986
2020-12-07 13:14:18 -08:00
jax authors
2a699d0b04
Merge pull request #5038 from skye:sharded_jit_namespace
...
PiperOrigin-RevId: 346121390
2020-12-07 10:18:56 -08:00
jax authors
47e4f062f6
Merge pull request #4639 from petebu:changelist/335389665
...
PiperOrigin-RevId: 346110001
2020-12-07 09:29:32 -08:00
Jake VanderPlas
8a00b4e0ee
lax.convert_element_type: always return DeviceArray
2020-12-07 09:10:34 -08:00
jax authors
73f68c9791
Merge pull request #5124 from bchetioui:bump_jax2tf_version
...
PiperOrigin-RevId: 346080664
2020-12-07 06:40:12 -08:00
Jean-Baptiste Lespiau
57f1bab09b
Return PyBuffer directly from the C++ jax.jit.
...
PiperOrigin-RevId: 346080315
2020-12-07 06:39:50 -08:00
Benjamin Chetioui
9af7edb331
[jax2tf] Bump TF nightly version in README.
2020-12-07 15:31:21 +01:00
jax authors
d8aabdb8c6
Merge pull request #5059 from jpuigcerver:all-to-all-groups
...
PiperOrigin-RevId: 346060720
2020-12-07 04:05:26 -08:00
Joan Puigcerver
85fbc6d790
Add axis_index_groups argument to all_to_all.
2020-12-07 11:52:42 +00:00