8384 Commits

Author SHA1 Message Date
tlu7
d97b393694 Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
jax authors
c97d63dec3 Merge pull request #7130 from google:custom-jvp-hotfix
PiperOrigin-RevId: 382665377
2021-07-01 21:23:17 -07:00
jax authors
7c45105d17 Merge pull request #7166 from jakevdp:pin-pillow
PiperOrigin-RevId: 382633599
2021-07-01 16:56:36 -07:00
Jake VanderPlas
4ba343aa83 CI: pin pillow dependency to 8.2 to avoid failures under 8.3 2021-07-01 16:32:35 -07:00
jax authors
b8b89d6ef7 Merge pull request #7150 from jakevdp:sparse-nnz
PiperOrigin-RevId: 382538132
2021-07-01 09:16:21 -07:00
jax authors
8061ae1676 Merge pull request #7100 from google:remat-fix-3
PiperOrigin-RevId: 382472685
2021-07-01 01:38:01 -07:00
Jake VanderPlas
76f9e6f016 [sparse] globally change nnz->nse 2021-06-30 17:46:02 -07:00
Matthew Johnson
a0eb1126e4 remat: don't apply cse-foiling widget to primal 2021-06-30 09:29:47 -07:00
jax authors
5978797528 Merge pull request #7137 from shawwn:patch-2
PiperOrigin-RevId: 382253683
2021-06-30 01:32:12 -07:00
jax authors
faef96a802 Merge pull request #7132 from shawwn:patch-1
PiperOrigin-RevId: 382198478
2021-06-29 17:20:57 -07:00
Shawn Presser
7b879c811d Fix typo in xmap_tutorial 2021-06-29 17:09:57 -05:00
Shawn Presser
76773798b2 Fix broken link in jax-101/06-parallelism 2021-06-29 17:02:45 -05:00
jax authors
90704c962f Merge pull request #7113 from zhangqiaorjc:pmap_vlog
PiperOrigin-RevId: 382137105
2021-06-29 12:13:59 -07:00
jax authors
4787825cc8 Merge pull request #7127 from zhangqiaorjc:updatechangelog
PiperOrigin-RevId: 382136863
2021-06-29 12:09:03 -07:00
jax authors
d4a7049718 Merge pull request #7126 from jakevdp:install-zsh
PiperOrigin-RevId: 382107958
2021-06-29 10:00:22 -07:00
jax authors
cbfa539053 Merge pull request #7105 from jakevdp:suppress-warning
PiperOrigin-RevId: 382071028
2021-06-29 06:25:59 -07:00
jax authors
653c9517bc Merge pull request #7138 from gnecula:tf_fix_float0
PiperOrigin-RevId: 382070092
2021-06-29 06:18:51 -07:00
jax authors
5fefdf1f30 Merge pull request #7133 from majnemer:patch-4
PiperOrigin-RevId: 382070089
2021-06-29 06:18:33 -07:00
jax authors
76bb26bb84 Merge pull request #7073 from colemanliyah:thread_saftey
PiperOrigin-RevId: 382069195
2021-06-29 06:11:13 -07:00
George Necula
ffd8fb84b4 [jax2tf] A few fixed for handling of float0 in jax2tf and call_tf
TF returns None or 0 for the gradients of functions with integer
arguments. JAX expects float0. We must convert to and from float0
at the JAX-TF boundary.
2021-06-29 15:48:24 +03:00
George Necula
238e8d0ac1 Improve the XlaDot shape inference rule for partially known shapes.
Added more shape inference tests.

PiperOrigin-RevId: 382032084
2021-06-29 01:19:43 -07:00
David Majnemer
781f85b09c
Fix broken markdown
A backtick was missing.
2021-06-28 23:35:26 -07:00
James Bradbury
e5d84522b7
custom_jvp named shape hotfix
When we added "avals with names", we intended to start by making the distinction between types with and without named axes load-bearing only in specific parts of the system, while (continuing to) ignore it elsewhere. This fixes a spot I missed, and that a user ran into.

Most likely, we'll want to restore something like this typecheck after vmap and pmap use avals with names; for now, the typecheck won't always be satisfied in those contexts and needs to be loosened.
2021-06-28 19:47:50 -07:00
Liyah Coleman
0ea0525740 implementing thread saftey check 2021-06-28 23:23:49 +00:00
Qiao Zhang
61ab59c40a Update changelog for jax and jaxlib releases. 2021-06-28 13:52:19 -07:00
Jake VanderPlas
5e2681939a DOC: update installation instructions for compatibility with zsh 2021-06-28 12:57:24 -07:00
Jake VanderPlas
c8e571ad84 Allow suppression of GPU warning via jax_platform_name 2021-06-28 12:54:21 -07:00
jax authors
0d68dbd619 Merge pull request #7124 from jakevdp:fix-flake
PiperOrigin-RevId: 381913589
2021-06-28 11:51:45 -07:00
Jake VanderPlas
5ed9471b9a flake: fix unused import 2021-06-28 11:40:23 -07:00
jax authors
cb8582c63d Merge pull request #6929 from jakevdp:sparsify
PiperOrigin-RevId: 381897788
2021-06-28 10:43:07 -07:00
jax authors
698555f082 Merge pull request #7119 from gnecula:call_tf_eager
PiperOrigin-RevId: 381826406
2021-06-28 03:28:47 -07:00
jax authors
d80f9d69f1 Merge pull request #7121 from gnecula:tf_reduce
PiperOrigin-RevId: 381825396
2021-06-28 03:19:49 -07:00
George Necula
44b95426d1 [jax2tf] Fix the conversion of reduce_sum and reduce_prod for booleans
Also update the documentation
2021-06-28 08:40:43 +02:00
George Necula
33b5164be9 [call_tf] Improved call_tf for op-by-op executions.
There are two major improvements here. First we ensure that
in op-by-op execution we can even execute functions that are not
compileable. We do this by ensuring that we do not trace the
TF function to a graph too early.

The other improvement is to work around some bugs in the TF shape
inference. Some TF graphs has unknown output shapes even when traced
with known inputs shapes. This happens even for some graph that
are generated by jax2tf, which we know should have known shapes. To
work around this, we get the output shapes for the TF function using
the XLA compiler, which is more reliably able to figure out the output
shapes. We do this even during abstract evaluation of the call_tf
primitive, and we use caching to ensure we do not call the TF
compiler repeatedly.
2021-06-27 16:40:45 +02:00
Qiao Zhang
57669bf401 Lazy eval for vlogging on pmap/pjit critical path.
See absl/logging comments
https://cs.opensource.google/bazel/bazel/+/master:third_party/py/abseil/absl/logging/__init__.py;l=44?q=absl%20logging
2021-06-25 19:35:16 -07:00
George Karpenkov
a50c2732b6 Reland bf16 gemm: rollback of rollback
PiperOrigin-RevId: 381579643
2021-06-25 17:47:03 -07:00
jax authors
f647b659a7 Merge pull request #7112 from hawkinsp:docs
PiperOrigin-RevId: 381574720
2021-06-25 17:04:43 -07:00
Peter Hawkins
b746d021c8 Disable xmap_tutorial to fix doc CI build. 2021-06-25 19:55:23 -04:00
jax authors
9a22498a91 Merge pull request #7023 from jakevdp:dot-general-sampled
PiperOrigin-RevId: 381547760
2021-06-25 14:27:15 -07:00
Jake VanderPlas
157d0121ef [sparse] add bcoo_dot_general_sampled primitive 2021-06-25 11:07:48 -07:00
Jake VanderPlas
0401d2be57 Add experimental sparsify transform
Co-authored-by: Roy Frostig <frostig@google.com>
2021-06-25 10:45:16 -07:00
jax authors
96575216a7 Merge pull request #6981 from pschuh:parallel-multidevice
PiperOrigin-RevId: 381377185
2021-06-24 18:09:55 -07:00
jax authors
e8695c20e7 Merge pull request #7101 from inailuig:linearize_lifted_jvp_pytree
PiperOrigin-RevId: 381361421
2021-06-24 16:30:59 -07:00
Clemens Giuliani
3041c18250 turn lifted_jvp into a PyTree 2021-06-24 23:55:49 +02:00
jax authors
df3cc0d980 Merge pull request #7086 from jakevdp:install-doc
PiperOrigin-RevId: 381328146
2021-06-24 13:43:14 -07:00
Jake VanderPlas
3c727ab604 DOC: update install instructions for GPU & TPU 2021-06-24 11:22:12 -07:00
jax authors
f3fb0c4e8f Merge pull request #7094 from hawkinsp:numpy
PiperOrigin-RevId: 381266625
2021-06-24 09:07:24 -07:00
jax authors
380ab090ab Merge pull request #7096 from hawkinsp:numpy2
PiperOrigin-RevId: 381266444
2021-06-24 09:03:45 -07:00
Peter Hawkins
988a2cd3d5 Relax NumPy version check in tests.
Only look at the major, minor, and patch parts of the NumPy version. There might be non-integer parts of the version for development builds of NumPy.
2021-06-24 11:30:31 -04:00
Peter Hawkins
15fe683945 Disable float0 tests that fail under NumPy 1.21.
https://github.com/numpy/numpy/issues/19305
2021-06-24 11:30:16 -04:00