Jean-Baptiste Lespiau
36f9126109
Copybara import of the project:
...
--
8eac55e71e9a023ce9dce9f79461b7265ae28f00 by Jean-Baptiste Lespiau <jblespiau@google.com>:
Have the C++ path raise a cleaner error for static argnums.
Fixes #5190 .
PiperOrigin-RevId: 348102847
2020-12-17 15:15:56 -08:00
jax authors
e90de8934d
Merge pull request #5220 from jakevdp:fix-4672
...
PiperOrigin-RevId: 348091777
2020-12-17 14:14:43 -08:00
jax authors
a8518769a2
Merge pull request #5115 from inailuig:rocm-gpukernels
...
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Jake VanderPlas
98f88152cd
Fix bug in primitive_computation
...
fixes #4672
2020-12-17 12:51:35 -08:00
jax authors
1a5b186ceb
Merge pull request #5219 from google:issue5217
...
PiperOrigin-RevId: 348074156
2020-12-17 12:39:53 -08:00
Matthew Johnson
d7b5e3b5d4
add add_any to jet rules table
...
fixes #5217
2020-12-17 12:10:12 -08:00
jax authors
943c7794f9
Merge pull request #5207 from bchetioui:remake_binary_elementwise_harnesses
...
PiperOrigin-RevId: 348026395
2020-12-17 08:31:41 -08:00
Benjamin Chetioui
c09a73abda
[jax2tf] Systematize broadcasting tests for binary elementwise harnesses.
...
Also add broadcasting tests to min and max, and splits the logic
for add and mul.
2020-12-17 15:52:54 +01:00
jax authors
c256222dc5
Merge pull request #5214 from google:duck-shaped-args
...
PiperOrigin-RevId: 347968970
2020-12-16 23:44:43 -08:00
Roy Frostig
391967809a
document the option for duck typing in api.xla_computation example arguments
2020-12-16 21:18:58 -08:00
jax authors
7d8b7d4187
Merge pull request #5212 from zhangqiaorjc:jtf1
...
PiperOrigin-RevId: 347918677
2020-12-16 16:20:30 -08:00
Qiao Zhang
efb8915f31
Bump jax2tf test_digamma tol for CPU.
2020-12-16 14:38:16 -08:00
jax authors
3f136acc22
Merge pull request #5174 from jakevdp:lax-weak-types
...
PiperOrigin-RevId: 347889057
2020-12-16 13:43:16 -08:00
jax authors
6ea9f16426
Merge pull request #5204 from minoring:pad-empty
...
PiperOrigin-RevId: 347848088
2020-12-16 10:23:47 -08:00
Jake VanderPlas
c820dbf44c
Propagate weak_types in remaining lax primitives
2020-12-16 09:53:30 -08:00
Benjamin Chetioui
08f2c7652b
[jax2tf] Clean up and expand binary elementwise harnesses.
2020-12-16 17:24:32 +01:00
jax authors
43f603bae6
Merge pull request #5187 from bchetioui:remake_elementwise_harnesses
...
PiperOrigin-RevId: 347816538
2020-12-16 07:15:45 -08:00
Clemens Giuliani
4981c53ac1
Add BLAS and LAPACK gpu kernels for ROCm
2020-12-16 16:00:17 +01:00
Clemens Giuliani
c128bdd90c
extract the shared handle pool code from cublas and cusolver
2020-12-16 16:00:16 +01:00
Benjamin Chetioui
299c7b9012
[jax2tf] Relax unary elementwise test tolerance.
2020-12-16 14:51:58 +01:00
Benjamin Chetioui
1dc55aabdf
[jax2tf] Update limitations with information about potentially
...
equally valid but different expected results.
2020-12-16 11:05:15 +01:00
jax authors
4d78d60a24
Merge pull request #5205 from gnecula:print_no_receiver
...
PiperOrigin-RevId: 347781549
2020-12-16 01:57:26 -08:00
jax authors
0b1771b867
Merge pull request #5196 from bchetioui:fix_round_p_harness
...
PiperOrigin-RevId: 347781466
2020-12-16 01:53:33 -08:00
George Necula
24d850815f
[host_callback] Remove deprecated outfeed_receiver context manager
2020-12-16 11:43:05 +02:00
jax authors
7c294e62f4
Copybara import of the project:
...
--
7342318774c6f1195f0e238f1209425109ea8944 by Matthew Johnson <mattjj@google.com>:
check for __jax_array__ method for conversion
--
6742016382b0511f5ac9ec21f67d2122a9f37cb7 by Matthew Johnson <mattjj@google.com>:
fix typo
--
5eb36855e53d8d4e81e281d08dc9264d2671f21f by Matthew Johnson <mattjj@google.com>:
ensure some jnp funs duck-type with __jax_array__
PiperOrigin-RevId: 347763582
2020-12-15 23:13:29 -08:00
minoring
3bdcb1211a
Implement jax.numpy.pad empty mode
2020-12-16 16:06:57 +09:00
jax authors
00926928c4
Merge pull request #4725 from google:handle-dunder-array-classes
...
PiperOrigin-RevId: 347754954
2020-12-15 21:54:41 -08:00
Matthew Johnson
5eb36855e5
ensure some jnp funs duck-type with __jax_array__
2020-12-15 20:46:09 -08:00
jax authors
8eedf1edc0
Merge pull request #5176 from jakevdp:isinstance-check
...
PiperOrigin-RevId: 347697032
2020-12-15 14:47:50 -08:00
David Majnemer
0f45d1fad3
TPU supports c64 -> c64 bitcast_convert_type
...
PiperOrigin-RevId: 347659193
2020-12-15 11:43:50 -08:00
jax authors
ad2349130d
Merge pull request #5193 from hawkinsp:gather
...
PiperOrigin-RevId: 347644033
2020-12-15 10:35:50 -08:00
jax authors
aaf05264d0
Merge pull request #5180 from jblespiau:changelist/347208988
...
PiperOrigin-RevId: 347631638
2020-12-15 09:41:51 -08:00
Jean-Baptiste Lespiau
ca72d3dc80
Fix a typo on the dynamic definition of __hash__.
2020-12-15 15:32:55 +01:00
jax authors
a79055cb11
Merge pull request #5182 from gnecula:print_pmap
...
PiperOrigin-RevId: 347591016
2020-12-15 05:18:55 -08:00
Benjamin Chetioui
4f9000a78f
[jax2tf] Minor fixes in round harness.
2020-12-15 13:27:48 +01:00
George Necula
20be478a6e
[host_callback] Add support for pmap and for passing the device to tap
...
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
`id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap
Issue: #5134
Fixes : #5169
2020-12-15 10:46:23 +02:00
Peter Hawkins
308e7f95b0
Fix batching rule for gather where the batch dimension has size 0.
2020-12-14 22:27:34 -05:00
Matthew Johnson
6742016382
fix typo
2020-12-14 17:13:27 -08:00
Matthew Johnson
7342318774
check for __jax_array__ method for conversion
2020-12-14 17:09:25 -08:00
jax authors
34bc6ca987
Merge pull request #5191 from skye:fix_colab_tpu
...
PiperOrigin-RevId: 347500520
2020-12-14 17:04:58 -08:00
Skye Wanderman-Milne
85796cc7e3
Fix colab_tpu.setup_tpu import in example Cloud TPU notebooks.
2020-12-14 16:23:20 -08:00
jax authors
0ca612c552
Merge pull request #5189 from hawkinsp:herbie
...
PiperOrigin-RevId: 347476608
2020-12-14 15:04:05 -08:00
jax authors
4390f3beca
Merge pull request #5183 from AntoinePlumerault:patch-1
...
PiperOrigin-RevId: 347476597
2020-12-14 15:00:44 -08:00
Jake VanderPlas
a0562dc9c9
api: handle numpy integers for static argnums
2020-12-14 14:52:51 -08:00
jax authors
931f925979
Merge pull request #5177 from minoring:pad-reflect-type
...
PiperOrigin-RevId: 347446703
2020-12-14 12:43:44 -08:00
Peter Hawkins
de8df3a86f
Improve a few JVP rules with rewrites from Herbie.
2020-12-14 11:48:16 -05:00
jax authors
a28436bc5f
Merge pull request #5163 from bchetioui:factor_test_utils
...
PiperOrigin-RevId: 347392991
2020-12-14 08:41:46 -08:00
jax authors
7ac3c04a00
Merge pull request #5137 from hawkinsp:dot
...
PiperOrigin-RevId: 347387515
2020-12-14 08:09:09 -08:00
Benjamin Chetioui
4bba27a446
[jax2tf] Remove workaround from sign_p and add limitations.
2020-12-14 12:00:36 +01:00
Benjamin Chetioui
7d4835786a
[jax2tf] Expand coverage for unary elementwise ops.
2020-12-14 09:35:19 +01:00