9874 Commits

Author SHA1 Message Date
Matthew Johnson
c1f71d17c0 generalize assert primitive, allow recharging 2021-12-02 14:35:23 -08:00
jax authors
40912d5d96 Merge pull request #8777 from google:yashk2810-patch-1
PiperOrigin-RevId: 413762284
2021-12-02 13:52:58 -08:00
Yash Katariya
20b70aa5e0
Update workspace file for py 3.10 release test 2021-12-02 13:46:46 -08:00
Yash Katariya
271bd3237b Add python 3.10 bazel build configs for jaxlib.
PiperOrigin-RevId: 413755904
2021-12-02 13:27:28 -08:00
jax authors
250d160640 Merge pull request #8773 from mattjj:checkify
PiperOrigin-RevId: 413749382
2021-12-02 12:59:45 -08:00
jax authors
4444a0771c Merge pull request #8760 from jakevdp:lax-numpy-changes
PiperOrigin-RevId: 413747082
2021-12-02 12:50:23 -08:00
Peter Hawkins
f306d33fda [MLIR] Implement input-output aliasing via the tf.aliasing_output attribute.
PiperOrigin-RevId: 413746577
2021-12-02 12:46:27 -08:00
Matthew Johnson
768b076420 add an assert primitive
The assert primitive has an effectful API and so it can't be staged out;
it's only a trace-time primitive. It can be discharged to the functional
form.

We might want to have separate transforms for discharging errors and for
adding error checks. But right now they're just bundled together in the
checkify transform.
2021-12-02 11:33:56 -08:00
Peter Hawkins
a28c4eb149 [MLIR] Remove option to tuple function results.
The MHLO->HLO converter handles untupled function results just fine, so we don't need to support this directly in JAX.

PiperOrigin-RevId: 413715299
2021-12-02 10:37:54 -08:00
jax authors
99182372c9 Merge pull request #8743 from jakevdp:scatter-weak-type
PiperOrigin-RevId: 413691628
2021-12-02 09:01:19 -08:00
jax authors
ea98cf42d9 Merge pull request #8754 from jakevdp:lax-dtype-none
PiperOrigin-RevId: 413690787
2021-12-02 08:57:11 -08:00
Jake VanderPlas
da319cf302 [sparse] refactor jax.experimental.sparse
Why? Better organization, and to avoid issues with circular imports.

PiperOrigin-RevId: 413679493
2021-12-02 08:02:15 -08:00
jax authors
b4dcf2b08d Merge pull request #8766 from hawkinsp:bazel421
PiperOrigin-RevId: 413671952
2021-12-02 07:33:53 -08:00
Peter Hawkins
3597b44559 Express xmap tile/untile logic in lax via xla.lower_fun().
The lax APIs are simpler and avoid the need to port the code to MHLO.

PiperOrigin-RevId: 413657577
2021-12-02 06:12:13 -08:00
Peter Hawkins
ffb7ec1651 Update Bazel to 4.2.1.
Fixes #8573
2021-12-02 09:11:38 -05:00
jax authors
404c3c7d25 Merge pull request #8718 from jakevdp:config-doc
PiperOrigin-RevId: 413630185
2021-12-02 03:14:31 -08:00
Jake VanderPlas
9d9244e33c [x64] make jax.numpy functionality respect default dtypes 2021-12-01 15:42:50 -08:00
jax authors
7869a6cb75 Merge pull request #8753 from mattjj:checkify
PiperOrigin-RevId: 413513067
2021-12-01 14:34:17 -08:00
Matthew Johnson
659f8b794f add skeleton checkify transformation 2021-12-01 10:44:58 -08:00
Jake VanderPlas
b977028022 lax.convert_element_type: better validation for new_dtype 2021-12-01 10:33:26 -08:00
Peter Hawkins
9394350727 Refactor uses of xla.call_translations to use xla.register_translation.
No functional changes intended.

PiperOrigin-RevId: 413443279
2021-12-01 09:58:43 -08:00
jax authors
610187cd33 Copybara import of the project:
--
6288d67fefec5dad8b0a42055e071d647ba1be06 by George Necula <gcnecula@gmail.com>:

Refactor the lowering of remat.

Replace the XLA-specific lowering with the same logic written
with LAX primitives. Then use this for HLO, MLIR, and jax2tf lowerings.

PiperOrigin-RevId: 413425828
2021-12-01 08:43:09 -08:00
George Necula
76d94391a1 Copybara import of the project:
--
b55e993be3aa2f0f846f6bb0935ae3d7c59a922a by George Necula <gcnecula@gmail.com>:

[jax2tf] Add tests for rng_uniform

These tests really only verify which dtypes are supported.
The actual numeric comparison is disabled, because rng_uniform
is statefull and calling multiple times produces different results.

PiperOrigin-RevId: 413420796
2021-12-01 08:18:03 -08:00
jax authors
baec8dcac0 Merge pull request #8748 from gnecula:clean_call_tf_test
PiperOrigin-RevId: 413357314
2021-12-01 02:21:24 -08:00
jax authors
b4b10f98f3 Merge pull request #8727 from gnecula:tf_fix_rng_test
PiperOrigin-RevId: 413355568
2021-12-01 02:10:55 -08:00
George Necula
5b177967e6 [jax2tf] Improved error checking for call_tf for functions with dynamic
shapes.
2021-12-01 11:57:05 +02:00
George Necula
b55e993be3 [jax2tf] Add tests for rng_uniform
These tests really only verify which dtypes are supported.
The actual numeric comparison is disabled, because rng_uniform
is statefull and calling multiple times produces different results.
2021-12-01 11:53:27 +02:00
jax authors
b7d7936e99 Merge pull request #8716 from gnecula:remat_translation_lax
PiperOrigin-RevId: 413351001
2021-12-01 01:46:36 -08:00
jax authors
4d9d6497ce Merge pull request #8742 from jakevdp:sparse-empty
PiperOrigin-RevId: 413333035
2021-11-30 23:46:15 -08:00
George Necula
6288d67fef Refactor the lowering of remat.
Replace the XLA-specific lowering with the same logic written
with LAX primitives. Then use this for HLO, MLIR, and jax2tf lowerings.
2021-12-01 09:40:37 +02:00
jax authors
88242bbff9 Merge pull request #8741 from jakevdp:dtype-none
PiperOrigin-RevId: 413331832
2021-11-30 23:39:12 -08:00
jax authors
6bb882cb56 Merge pull request #8722 from jakevdp:linalg-weak-types
PiperOrigin-RevId: 413317854
2021-11-30 21:44:55 -08:00
jax authors
800aac8fd3 Merge pull request #8681 from jakevdp:numpy-faq
PiperOrigin-RevId: 413316336
2021-11-30 21:33:37 -08:00
jax authors
3f2a4b10f6 Merge pull request #8650 from tlu7:update-qdwh-test
PiperOrigin-RevId: 413290729
2021-11-30 18:18:15 -08:00
Tianjian Lu
19554e21d3 Enable QDWH TPU tests. 2021-11-30 15:47:50 -08:00
Jake VanderPlas
47e88ded05 [x64] ensure scatter functionality preserves weak_type 2021-11-30 15:43:06 -08:00
Jake VanderPlas
0b872bb5d0 [sparse] add sparse.empty() utility to create empty matrix 2021-11-30 14:45:43 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Jake VanderPlas
03197dd298 [x64] improve consistency in handling dtype=None 2021-11-30 13:51:38 -08:00
jax authors
dd6a6f206c Merge pull request #8339 from mariogeiger:main
PiperOrigin-RevId: 413227294
2021-11-30 13:15:03 -08:00
Peter Hawkins
34ec805698 [MLIR] Fix test failures on GPU and TPU.
PiperOrigin-RevId: 413226939
2021-11-30 13:11:01 -08:00
jax authors
737d021fd2 Merge pull request #8329 from AdrienCorenflos:patch-1
PiperOrigin-RevId: 413223979
2021-11-30 12:57:58 -08:00
jax authors
6bac78d500 Merge pull request #8735 from pschuh:update-docs
PiperOrigin-RevId: 413200537
2021-11-30 11:21:41 -08:00
Jake VanderPlas
022f8ac2ee [x64] preserve weak types in jax.scipy.sparse solvers 2021-11-30 10:36:28 -08:00
Parker Schuh
46a1033311 Update device_get docs to mention parrallelism. 2021-11-30 10:20:11 -08:00
jax authors
aacee8f0c4 Merge pull request #8719 from jakevdp:pad-weak-type
PiperOrigin-RevId: 413166641
2021-11-30 09:18:48 -08:00
jax authors
8c2c054ab2 Merge pull request #8728 from LenaMartens:changelist/413128290
PiperOrigin-RevId: 413166628
2021-11-30 09:14:46 -08:00
Peter Hawkins
db3c3aae87 [JAX] Correctly propagate Python errors out of pytree code when the keys of an enum value cannot be sorted.
Also catch std::runtime_error since the pytree code may throw it.

PiperOrigin-RevId: 413160923
2021-11-30 08:50:30 -08:00
Peter Hawkins
fa411d864e [MLIR] Fix CPU test failures for MLIR lowering.
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.

Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
2021-11-30 06:08:55 -08:00
Lena Martens
cb6a3f216f Leak checker: garbage collect before collecting hanging references. 2021-11-30 14:02:51 +00:00