6348 Commits

Author SHA1 Message Date
Jean-Baptiste Lespiau
36f9126109 Copybara import of the project:
--
8eac55e71e9a023ce9dce9f79461b7265ae28f00 by Jean-Baptiste Lespiau <jblespiau@google.com>:

Have the C++ path raise a cleaner error for static argnums.

Fixes #5190.

PiperOrigin-RevId: 348102847
2020-12-17 15:15:56 -08:00
jax authors
e90de8934d Merge pull request #5220 from jakevdp:fix-4672
PiperOrigin-RevId: 348091777
2020-12-17 14:14:43 -08:00
jax authors
a8518769a2 Merge pull request #5115 from inailuig:rocm-gpukernels
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Jake VanderPlas
98f88152cd Fix bug in primitive_computation
fixes #4672
2020-12-17 12:51:35 -08:00
jax authors
1a5b186ceb Merge pull request #5219 from google:issue5217
PiperOrigin-RevId: 348074156
2020-12-17 12:39:53 -08:00
Matthew Johnson
d7b5e3b5d4 add add_any to jet rules table
fixes #5217
2020-12-17 12:10:12 -08:00
jax authors
943c7794f9 Merge pull request #5207 from bchetioui:remake_binary_elementwise_harnesses
PiperOrigin-RevId: 348026395
2020-12-17 08:31:41 -08:00
Benjamin Chetioui
c09a73abda [jax2tf] Systematize broadcasting tests for binary elementwise harnesses.
Also add broadcasting tests to min and max, and splits the logic
for add and mul.
2020-12-17 15:52:54 +01:00
jax authors
c256222dc5 Merge pull request #5214 from google:duck-shaped-args
PiperOrigin-RevId: 347968970
2020-12-16 23:44:43 -08:00
Roy Frostig
391967809a document the option for duck typing in api.xla_computation example arguments 2020-12-16 21:18:58 -08:00
jax authors
7d8b7d4187 Merge pull request #5212 from zhangqiaorjc:jtf1
PiperOrigin-RevId: 347918677
2020-12-16 16:20:30 -08:00
Qiao Zhang
efb8915f31 Bump jax2tf test_digamma tol for CPU. 2020-12-16 14:38:16 -08:00
jax authors
3f136acc22 Merge pull request #5174 from jakevdp:lax-weak-types
PiperOrigin-RevId: 347889057
2020-12-16 13:43:16 -08:00
jax authors
6ea9f16426 Merge pull request #5204 from minoring:pad-empty
PiperOrigin-RevId: 347848088
2020-12-16 10:23:47 -08:00
Jake VanderPlas
c820dbf44c Propagate weak_types in remaining lax primitives 2020-12-16 09:53:30 -08:00
Benjamin Chetioui
08f2c7652b [jax2tf] Clean up and expand binary elementwise harnesses. 2020-12-16 17:24:32 +01:00
jax authors
43f603bae6 Merge pull request #5187 from bchetioui:remake_elementwise_harnesses
PiperOrigin-RevId: 347816538
2020-12-16 07:15:45 -08:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Clemens Giuliani
c128bdd90c extract the shared handle pool code from cublas and cusolver 2020-12-16 16:00:16 +01:00
Benjamin Chetioui
299c7b9012 [jax2tf] Relax unary elementwise test tolerance. 2020-12-16 14:51:58 +01:00
Benjamin Chetioui
1dc55aabdf [jax2tf] Update limitations with information about potentially
equally valid but different expected results.
2020-12-16 11:05:15 +01:00
jax authors
4d78d60a24 Merge pull request #5205 from gnecula:print_no_receiver
PiperOrigin-RevId: 347781549
2020-12-16 01:57:26 -08:00
jax authors
0b1771b867 Merge pull request #5196 from bchetioui:fix_round_p_harness
PiperOrigin-RevId: 347781466
2020-12-16 01:53:33 -08:00
George Necula
24d850815f [host_callback] Remove deprecated outfeed_receiver context manager 2020-12-16 11:43:05 +02:00
jax authors
7c294e62f4 Copybara import of the project:
--
7342318774c6f1195f0e238f1209425109ea8944 by Matthew Johnson <mattjj@google.com>:

check for __jax_array__ method for conversion

--
6742016382b0511f5ac9ec21f67d2122a9f37cb7 by Matthew Johnson <mattjj@google.com>:

fix typo

--
5eb36855e53d8d4e81e281d08dc9264d2671f21f by Matthew Johnson <mattjj@google.com>:

ensure some jnp funs duck-type with __jax_array__

PiperOrigin-RevId: 347763582
2020-12-15 23:13:29 -08:00
minoring
3bdcb1211a Implement jax.numpy.pad empty mode 2020-12-16 16:06:57 +09:00
jax authors
00926928c4 Merge pull request #4725 from google:handle-dunder-array-classes
PiperOrigin-RevId: 347754954
2020-12-15 21:54:41 -08:00
Matthew Johnson
5eb36855e5 ensure some jnp funs duck-type with __jax_array__ 2020-12-15 20:46:09 -08:00
jax authors
8eedf1edc0 Merge pull request #5176 from jakevdp:isinstance-check
PiperOrigin-RevId: 347697032
2020-12-15 14:47:50 -08:00
David Majnemer
0f45d1fad3 TPU supports c64 -> c64 bitcast_convert_type
PiperOrigin-RevId: 347659193
2020-12-15 11:43:50 -08:00
jax authors
ad2349130d Merge pull request #5193 from hawkinsp:gather
PiperOrigin-RevId: 347644033
2020-12-15 10:35:50 -08:00
jax authors
aaf05264d0 Merge pull request #5180 from jblespiau:changelist/347208988
PiperOrigin-RevId: 347631638
2020-12-15 09:41:51 -08:00
Jean-Baptiste Lespiau
ca72d3dc80 Fix a typo on the dynamic definition of __hash__. 2020-12-15 15:32:55 +01:00
jax authors
a79055cb11 Merge pull request #5182 from gnecula:print_pmap
PiperOrigin-RevId: 347591016
2020-12-15 05:18:55 -08:00
Benjamin Chetioui
4f9000a78f [jax2tf] Minor fixes in round harness. 2020-12-15 13:27:48 +01:00
George Necula
20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00
Peter Hawkins
308e7f95b0 Fix batching rule for gather where the batch dimension has size 0. 2020-12-14 22:27:34 -05:00
Matthew Johnson
6742016382 fix typo 2020-12-14 17:13:27 -08:00
Matthew Johnson
7342318774 check for __jax_array__ method for conversion 2020-12-14 17:09:25 -08:00
jax authors
34bc6ca987 Merge pull request #5191 from skye:fix_colab_tpu
PiperOrigin-RevId: 347500520
2020-12-14 17:04:58 -08:00
Skye Wanderman-Milne
85796cc7e3 Fix colab_tpu.setup_tpu import in example Cloud TPU notebooks. 2020-12-14 16:23:20 -08:00
jax authors
0ca612c552 Merge pull request #5189 from hawkinsp:herbie
PiperOrigin-RevId: 347476608
2020-12-14 15:04:05 -08:00
jax authors
4390f3beca Merge pull request #5183 from AntoinePlumerault:patch-1
PiperOrigin-RevId: 347476597
2020-12-14 15:00:44 -08:00
Jake VanderPlas
a0562dc9c9 api: handle numpy integers for static argnums 2020-12-14 14:52:51 -08:00
jax authors
931f925979 Merge pull request #5177 from minoring:pad-reflect-type
PiperOrigin-RevId: 347446703
2020-12-14 12:43:44 -08:00
Peter Hawkins
de8df3a86f Improve a few JVP rules with rewrites from Herbie. 2020-12-14 11:48:16 -05:00
jax authors
a28436bc5f Merge pull request #5163 from bchetioui:factor_test_utils
PiperOrigin-RevId: 347392991
2020-12-14 08:41:46 -08:00
jax authors
7ac3c04a00 Merge pull request #5137 from hawkinsp:dot
PiperOrigin-RevId: 347387515
2020-12-14 08:09:09 -08:00
Benjamin Chetioui
4bba27a446 [jax2tf] Remove workaround from sign_p and add limitations. 2020-12-14 12:00:36 +01:00
Benjamin Chetioui
7d4835786a [jax2tf] Expand coverage for unary elementwise ops. 2020-12-14 09:35:19 +01:00