3989 Commits

Author SHA1 Message Date
George Necula
20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00
jax authors
0ca612c552 Merge pull request #5189 from hawkinsp:herbie
PiperOrigin-RevId: 347476608
2020-12-14 15:04:05 -08:00
jax authors
931f925979 Merge pull request #5177 from minoring:pad-reflect-type
PiperOrigin-RevId: 347446703
2020-12-14 12:43:44 -08:00
Peter Hawkins
de8df3a86f Improve a few JVP rules with rewrites from Herbie. 2020-12-14 11:48:16 -05:00
jax authors
a28436bc5f Merge pull request #5163 from bchetioui:factor_test_utils
PiperOrigin-RevId: 347392991
2020-12-14 08:41:46 -08:00
jax authors
7ac3c04a00 Merge pull request #5137 from hawkinsp:dot
PiperOrigin-RevId: 347387515
2020-12-14 08:09:09 -08:00
Benjamin Chetioui
3a06dfec61 Set collect_limitations variable to True and remove original_impl. 2020-12-13 15:52:30 +01:00
minoring
c1d86739d1 Remove bool from odd reflect_type test 2020-12-13 09:21:05 +09:00
jax authors
91a0dd7163 Merge pull request #5121 from qiuminxu:add_tflite_example
PiperOrigin-RevId: 347177932
2020-12-12 09:53:03 -08:00
jax authors
b85e05956f Merge pull request #5161 from qiuminxu:jax2tf_avg_pool
PiperOrigin-RevId: 347167053
2020-12-12 07:10:17 -08:00
minoring
f0a248c84b Add reflect_type argument for symmetric and reflect modes in padding 2020-12-12 12:26:47 +09:00
jax authors
4d3e0fd892 Merge pull request #5160 from jakevdp:weak-types-slicing
PiperOrigin-RevId: 347094147
2020-12-11 16:08:52 -08:00
Skye Wanderman-Milne
e802bf9b63 Merge pull request #5173 from jakevdp:conv-elem-type-cleanup
PiperOrigin-RevId: 347085811
2020-12-11 15:27:40 -08:00
Jake VanderPlas
ebaa35e2b7 Preserve weak types in array slices 2020-12-11 15:06:01 -08:00
Jake VanderPlas
de96856be4 Cleanup: remove unnecessary code in convert_element_type 2020-12-11 14:14:05 -08:00
Qiumin Xu
542f33edf3 Update mnist.py 2020-12-11 12:40:03 -08:00
Adam Paszke
146c6eb308 Check for non-tuples and not int subclasses in Chunked
Apparently some projects like to pass in instances of `numpy.int64`
where `int`s are expected, and those fail the subclass check. This
should be a hotfix for them, though it would be good to figure out where
does the NumPy scalar come from, and make them well typed.
2020-12-11 19:07:34 +00:00
Qiumin Xu
aee42a33b1 Added a non-XLA conversion path for reduce_window_sum 2020-12-11 10:13:03 -08:00
jax authors
ca468940b4 Merge pull request #5099 from apaszke:xmap-multiple-mesh-dims
PiperOrigin-RevId: 347013166
2020-12-11 09:24:51 -08:00
jax authors
7ee2710d07 Merge pull request #5157 from jakevdp:weak-types-creation
PiperOrigin-RevId: 346996341
2020-12-11 07:44:22 -08:00
Adam Paszke
8fcacd645c Support mapping a single logical axis to multiple mesh axes in xmap 2020-12-11 14:35:31 +00:00
Benjamin Chetioui
b6114a011f [jax2tf] Factorize and clean up test utils. 2020-12-11 14:27:17 +01:00
Adam Paszke
8438ffe998 Make vectorization the default xmap schedule
It's been really annoying to have to spell out all the `vectorize`
schedule components, especially when dealing with lots of axes. And the
utility of checking that an axis is _not_ vectorized seems quite
limited.

Finally, one could argue that vectorization really is _the default_.
`xmap` can be seen as a generalization of `einsum`, or as a way to
expose programming with named axes instead of positional axes. In both
those cases, the focus is on shifting how the program is expressed,
without saying anything about the lowering or execution strategy. But,
vectorization is the default of the whole ecosystem, given the emphasis
most libraries put on broadcasting semantics. So it also makes sense to
adopt it here.
2020-12-11 11:26:21 +00:00
Adam Paszke
136f9eca6e Make xmap in/out_axes take in native Python containers
instead of forcing the pytree leaves to be the custom container exported
by xmap. This makes the API a bit less verbose, and also relaxes it so
that the mapping can be both specified through a dict (mapping
positional axes to named axes), as well as a list.

All this thanks to the recent pytree changes that let us terminate the
flattening before reaching true leaves.
2020-12-11 11:26:21 +00:00
Adam Paszke
f3bfdf8968 Expose is_leaf predicate for pytree.flatten
and add tests for it. The change has already been landed in the TF code,
where the C++ pytree components live. This is why I needed to bump the
commit.
2020-12-11 11:26:18 +00:00
Qiumin Xu
41d8cc729e Add a tflite mnist example for jax2tf. 2020-12-10 21:19:16 -08:00
Jake VanderPlas
99c69ec9c2 Propagate weak types in jnp.array, jnp.full, jnp.full_like, jnp.zeros_like, jnp.ones_like 2020-12-10 17:16:48 -08:00
Jake VanderPlas
33c2dc1153 Propagate weak types in jnp.array() 2020-12-10 17:16:48 -08:00
Jake VanderPlas
7b097340bf Fix lax.convert_element_type() with dtype=None 2020-12-10 14:14:36 -08:00
jax authors
245c6d2339 Merge pull request #4850 from jakevdp:weak-types-propagation
PiperOrigin-RevId: 346850541
2020-12-10 13:16:18 -08:00
Jake VanderPlas
eb9adcc92a jnp.arccosh: use same complex branch as numpy 2020-12-10 11:50:14 -08:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
jax authors
24a27e07f5 Merge pull request #5150 from jakevdp:allow-singular
PiperOrigin-RevId: 346814700
2020-12-10 10:32:38 -08:00
jax authors
5520a28a8a Merge pull request #5132 from gnecula:tf_colab
PiperOrigin-RevId: 346759909
2020-12-10 04:55:32 -08:00
Jake VanderPlas
8a8a48a926 multivariate_normal.logpdf: add (unimplemented) allow_singular argument 2020-12-09 11:39:06 -08:00
jax authors
2f4f940347 Merge pull request #5110 from minoring:linear-ramp-pad
PiperOrigin-RevId: 346592848
2020-12-09 11:13:15 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Peter Hawkins
450747cf72 Only use sum-of-products lowering of integer/bool dots on CPU.
Other XLA backends support integer dots.
2020-12-08 12:31:01 -05:00
George Necula
5f42929357 [jax2tf] Deprecate the getting_started colab 2020-12-08 13:37:54 +02:00
jax authors
fa15380d97 Merge pull request #5130 from hawkinsp:trig
PiperOrigin-RevId: 346277512
2020-12-08 01:59:36 -08:00
Peter Hawkins
706d78553c Upgrade trigonometric functions to primitives.
This allows for slightly simpler derivatives than differentiating their implementations.
2020-12-07 17:34:27 -05:00
jax authors
b1be39499e Merge pull request #5085 from jakevdp:conv-elem-type-devicearray
PiperOrigin-RevId: 346163986
2020-12-07 13:14:18 -08:00
jax authors
2a699d0b04 Merge pull request #5038 from skye:sharded_jit_namespace
PiperOrigin-RevId: 346121390
2020-12-07 10:18:56 -08:00
jax authors
47e4f062f6 Merge pull request #4639 from petebu:changelist/335389665
PiperOrigin-RevId: 346110001
2020-12-07 09:29:32 -08:00
Jake VanderPlas
8a00b4e0ee lax.convert_element_type: always return DeviceArray 2020-12-07 09:10:34 -08:00
jax authors
73f68c9791 Merge pull request #5124 from bchetioui:bump_jax2tf_version
PiperOrigin-RevId: 346080664
2020-12-07 06:40:12 -08:00
Jean-Baptiste Lespiau
57f1bab09b Return PyBuffer directly from the C++ jax.jit.
PiperOrigin-RevId: 346080315
2020-12-07 06:39:50 -08:00
Benjamin Chetioui
9af7edb331 [jax2tf] Bump TF nightly version in README. 2020-12-07 15:31:21 +01:00
jax authors
d8aabdb8c6 Merge pull request #5059 from jpuigcerver:all-to-all-groups
PiperOrigin-RevId: 346060720
2020-12-07 04:05:26 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00