6330 Commits

Author SHA1 Message Date
Benjamin Chetioui
08f2c7652b [jax2tf] Clean up and expand binary elementwise harnesses. 2020-12-16 17:24:32 +01:00
jax authors
43f603bae6 Merge pull request #5187 from bchetioui:remake_elementwise_harnesses
PiperOrigin-RevId: 347816538
2020-12-16 07:15:45 -08:00
Benjamin Chetioui
299c7b9012 [jax2tf] Relax unary elementwise test tolerance. 2020-12-16 14:51:58 +01:00
Benjamin Chetioui
1dc55aabdf [jax2tf] Update limitations with information about potentially
equally valid but different expected results.
2020-12-16 11:05:15 +01:00
jax authors
4d78d60a24 Merge pull request #5205 from gnecula:print_no_receiver
PiperOrigin-RevId: 347781549
2020-12-16 01:57:26 -08:00
jax authors
0b1771b867 Merge pull request #5196 from bchetioui:fix_round_p_harness
PiperOrigin-RevId: 347781466
2020-12-16 01:53:33 -08:00
George Necula
24d850815f [host_callback] Remove deprecated outfeed_receiver context manager 2020-12-16 11:43:05 +02:00
jax authors
7c294e62f4 Copybara import of the project:
--
7342318774c6f1195f0e238f1209425109ea8944 by Matthew Johnson <mattjj@google.com>:

check for __jax_array__ method for conversion

--
6742016382b0511f5ac9ec21f67d2122a9f37cb7 by Matthew Johnson <mattjj@google.com>:

fix typo

--
5eb36855e53d8d4e81e281d08dc9264d2671f21f by Matthew Johnson <mattjj@google.com>:

ensure some jnp funs duck-type with __jax_array__

PiperOrigin-RevId: 347763582
2020-12-15 23:13:29 -08:00
jax authors
00926928c4 Merge pull request #4725 from google:handle-dunder-array-classes
PiperOrigin-RevId: 347754954
2020-12-15 21:54:41 -08:00
Matthew Johnson
5eb36855e5 ensure some jnp funs duck-type with __jax_array__ 2020-12-15 20:46:09 -08:00
jax authors
8eedf1edc0 Merge pull request #5176 from jakevdp:isinstance-check
PiperOrigin-RevId: 347697032
2020-12-15 14:47:50 -08:00
David Majnemer
0f45d1fad3 TPU supports c64 -> c64 bitcast_convert_type
PiperOrigin-RevId: 347659193
2020-12-15 11:43:50 -08:00
jax authors
ad2349130d Merge pull request #5193 from hawkinsp:gather
PiperOrigin-RevId: 347644033
2020-12-15 10:35:50 -08:00
jax authors
aaf05264d0 Merge pull request #5180 from jblespiau:changelist/347208988
PiperOrigin-RevId: 347631638
2020-12-15 09:41:51 -08:00
Jean-Baptiste Lespiau
ca72d3dc80 Fix a typo on the dynamic definition of __hash__. 2020-12-15 15:32:55 +01:00
jax authors
a79055cb11 Merge pull request #5182 from gnecula:print_pmap
PiperOrigin-RevId: 347591016
2020-12-15 05:18:55 -08:00
Benjamin Chetioui
4f9000a78f [jax2tf] Minor fixes in round harness. 2020-12-15 13:27:48 +01:00
George Necula
20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00
Peter Hawkins
308e7f95b0 Fix batching rule for gather where the batch dimension has size 0. 2020-12-14 22:27:34 -05:00
Matthew Johnson
6742016382 fix typo 2020-12-14 17:13:27 -08:00
Matthew Johnson
7342318774 check for __jax_array__ method for conversion 2020-12-14 17:09:25 -08:00
jax authors
34bc6ca987 Merge pull request #5191 from skye:fix_colab_tpu
PiperOrigin-RevId: 347500520
2020-12-14 17:04:58 -08:00
Skye Wanderman-Milne
85796cc7e3 Fix colab_tpu.setup_tpu import in example Cloud TPU notebooks. 2020-12-14 16:23:20 -08:00
jax authors
0ca612c552 Merge pull request #5189 from hawkinsp:herbie
PiperOrigin-RevId: 347476608
2020-12-14 15:04:05 -08:00
jax authors
4390f3beca Merge pull request #5183 from AntoinePlumerault:patch-1
PiperOrigin-RevId: 347476597
2020-12-14 15:00:44 -08:00
Jake VanderPlas
a0562dc9c9 api: handle numpy integers for static argnums 2020-12-14 14:52:51 -08:00
jax authors
931f925979 Merge pull request #5177 from minoring:pad-reflect-type
PiperOrigin-RevId: 347446703
2020-12-14 12:43:44 -08:00
Peter Hawkins
de8df3a86f Improve a few JVP rules with rewrites from Herbie. 2020-12-14 11:48:16 -05:00
jax authors
a28436bc5f Merge pull request #5163 from bchetioui:factor_test_utils
PiperOrigin-RevId: 347392991
2020-12-14 08:41:46 -08:00
jax authors
7ac3c04a00 Merge pull request #5137 from hawkinsp:dot
PiperOrigin-RevId: 347387515
2020-12-14 08:09:09 -08:00
Benjamin Chetioui
4bba27a446 [jax2tf] Remove workaround from sign_p and add limitations. 2020-12-14 12:00:36 +01:00
Benjamin Chetioui
7d4835786a [jax2tf] Expand coverage for unary elementwise ops. 2020-12-14 09:35:19 +01:00
AntoinePlumerault
269676ee4e
Typo ?
I removed "-e" option from "pip install -e dist/*.whl  # installs jaxlib (includes XLA)" line 58. It is now coherent with lines 69-70. 
When I tried the command with the "-e" it threw an error, without "-e" it worked fine.
2020-12-13 17:03:07 +01:00
Benjamin Chetioui
3a06dfec61 Set collect_limitations variable to True and remove original_impl. 2020-12-13 15:52:30 +01:00
minoring
c1d86739d1 Remove bool from odd reflect_type test 2020-12-13 09:21:05 +09:00
jax authors
91a0dd7163 Merge pull request #5121 from qiuminxu:add_tflite_example
PiperOrigin-RevId: 347177932
2020-12-12 09:53:03 -08:00
jax authors
b85e05956f Merge pull request #5161 from qiuminxu:jax2tf_avg_pool
PiperOrigin-RevId: 347167053
2020-12-12 07:10:17 -08:00
minoring
f0a248c84b Add reflect_type argument for symmetric and reflect modes in padding 2020-12-12 12:26:47 +09:00
jax authors
4d3e0fd892 Merge pull request #5160 from jakevdp:weak-types-slicing
PiperOrigin-RevId: 347094147
2020-12-11 16:08:52 -08:00
Skye Wanderman-Milne
e802bf9b63 Merge pull request #5173 from jakevdp:conv-elem-type-cleanup
PiperOrigin-RevId: 347085811
2020-12-11 15:27:40 -08:00
Jake VanderPlas
ebaa35e2b7 Preserve weak types in array slices 2020-12-11 15:06:01 -08:00
jax authors
a5df588298 Merge pull request #5172 from jakevdp:boilerplate
PiperOrigin-RevId: 347076382
2020-12-11 14:32:11 -08:00
Jake VanderPlas
de96856be4 Cleanup: remove unnecessary code in convert_element_type 2020-12-11 14:14:05 -08:00
Jake VanderPlas
1a83bb6f90 Cleanup: remove remaining instances of rng_factory boilerplate 2020-12-11 13:47:46 -08:00
jax authors
92ae9d85ce Merge pull request #5138 from jakevdp:boilerplate
PiperOrigin-RevId: 347055073
2020-12-11 12:46:26 -08:00
jax authors
940b197c39 Merge pull request #5171 from apaszke:relax-check
PiperOrigin-RevId: 347054837
2020-12-11 12:42:51 -08:00
Qiumin Xu
542f33edf3 Update mnist.py 2020-12-11 12:40:03 -08:00
Adam Paszke
146c6eb308 Check for non-tuples and not int subclasses in Chunked
Apparently some projects like to pass in instances of `numpy.int64`
where `int`s are expected, and those fail the subclass check. This
should be a hotfix for them, though it would be good to figure out where
does the NumPy scalar come from, and make them well typed.
2020-12-11 19:07:34 +00:00
Qiumin Xu
aee42a33b1 Added a non-XLA conversion path for reduce_window_sum 2020-12-11 10:13:03 -08:00
jax authors
ca468940b4 Merge pull request #5099 from apaszke:xmap-multiple-mesh-dims
PiperOrigin-RevId: 347013166
2020-12-11 09:24:51 -08:00