Matthew Johnson
c1f71d17c0
generalize assert primitive, allow recharging
2021-12-02 14:35:23 -08:00
jax authors
40912d5d96
Merge pull request #8777 from google:yashk2810-patch-1
...
PiperOrigin-RevId: 413762284
2021-12-02 13:52:58 -08:00
Yash Katariya
20b70aa5e0
Update workspace file for py 3.10 release test
2021-12-02 13:46:46 -08:00
Yash Katariya
271bd3237b
Add python 3.10 bazel build configs for jaxlib.
...
PiperOrigin-RevId: 413755904
2021-12-02 13:27:28 -08:00
jax authors
250d160640
Merge pull request #8773 from mattjj:checkify
...
PiperOrigin-RevId: 413749382
2021-12-02 12:59:45 -08:00
jax authors
4444a0771c
Merge pull request #8760 from jakevdp:lax-numpy-changes
...
PiperOrigin-RevId: 413747082
2021-12-02 12:50:23 -08:00
Peter Hawkins
f306d33fda
[MLIR] Implement input-output aliasing via the tf.aliasing_output attribute.
...
PiperOrigin-RevId: 413746577
2021-12-02 12:46:27 -08:00
Matthew Johnson
768b076420
add an assert primitive
...
The assert primitive has an effectful API and so it can't be staged out;
it's only a trace-time primitive. It can be discharged to the functional
form.
We might want to have separate transforms for discharging errors and for
adding error checks. But right now they're just bundled together in the
checkify transform.
2021-12-02 11:33:56 -08:00
Peter Hawkins
a28c4eb149
[MLIR] Remove option to tuple function results.
...
The MHLO->HLO converter handles untupled function results just fine, so we don't need to support this directly in JAX.
PiperOrigin-RevId: 413715299
2021-12-02 10:37:54 -08:00
jax authors
99182372c9
Merge pull request #8743 from jakevdp:scatter-weak-type
...
PiperOrigin-RevId: 413691628
2021-12-02 09:01:19 -08:00
jax authors
ea98cf42d9
Merge pull request #8754 from jakevdp:lax-dtype-none
...
PiperOrigin-RevId: 413690787
2021-12-02 08:57:11 -08:00
Jake VanderPlas
da319cf302
[sparse] refactor jax.experimental.sparse
...
Why? Better organization, and to avoid issues with circular imports.
PiperOrigin-RevId: 413679493
2021-12-02 08:02:15 -08:00
jax authors
b4dcf2b08d
Merge pull request #8766 from hawkinsp:bazel421
...
PiperOrigin-RevId: 413671952
2021-12-02 07:33:53 -08:00
Peter Hawkins
3597b44559
Express xmap tile/untile logic in lax via xla.lower_fun().
...
The lax APIs are simpler and avoid the need to port the code to MHLO.
PiperOrigin-RevId: 413657577
2021-12-02 06:12:13 -08:00
Peter Hawkins
ffb7ec1651
Update Bazel to 4.2.1.
...
Fixes #8573
2021-12-02 09:11:38 -05:00
jax authors
404c3c7d25
Merge pull request #8718 from jakevdp:config-doc
...
PiperOrigin-RevId: 413630185
2021-12-02 03:14:31 -08:00
Jake VanderPlas
9d9244e33c
[x64] make jax.numpy functionality respect default dtypes
2021-12-01 15:42:50 -08:00
jax authors
7869a6cb75
Merge pull request #8753 from mattjj:checkify
...
PiperOrigin-RevId: 413513067
2021-12-01 14:34:17 -08:00
Matthew Johnson
659f8b794f
add skeleton checkify transformation
2021-12-01 10:44:58 -08:00
Jake VanderPlas
b977028022
lax.convert_element_type: better validation for new_dtype
2021-12-01 10:33:26 -08:00
Peter Hawkins
9394350727
Refactor uses of xla.call_translations to use xla.register_translation.
...
No functional changes intended.
PiperOrigin-RevId: 413443279
2021-12-01 09:58:43 -08:00
jax authors
610187cd33
Copybara import of the project:
...
--
6288d67fefec5dad8b0a42055e071d647ba1be06 by George Necula <gcnecula@gmail.com>:
Refactor the lowering of remat.
Replace the XLA-specific lowering with the same logic written
with LAX primitives. Then use this for HLO, MLIR, and jax2tf lowerings.
PiperOrigin-RevId: 413425828
2021-12-01 08:43:09 -08:00
George Necula
76d94391a1
Copybara import of the project:
...
--
b55e993be3aa2f0f846f6bb0935ae3d7c59a922a by George Necula <gcnecula@gmail.com>:
[jax2tf] Add tests for rng_uniform
These tests really only verify which dtypes are supported.
The actual numeric comparison is disabled, because rng_uniform
is statefull and calling multiple times produces different results.
PiperOrigin-RevId: 413420796
2021-12-01 08:18:03 -08:00
jax authors
baec8dcac0
Merge pull request #8748 from gnecula:clean_call_tf_test
...
PiperOrigin-RevId: 413357314
2021-12-01 02:21:24 -08:00
jax authors
b4b10f98f3
Merge pull request #8727 from gnecula:tf_fix_rng_test
...
PiperOrigin-RevId: 413355568
2021-12-01 02:10:55 -08:00
George Necula
5b177967e6
[jax2tf] Improved error checking for call_tf for functions with dynamic
...
shapes.
2021-12-01 11:57:05 +02:00
George Necula
b55e993be3
[jax2tf] Add tests for rng_uniform
...
These tests really only verify which dtypes are supported.
The actual numeric comparison is disabled, because rng_uniform
is statefull and calling multiple times produces different results.
2021-12-01 11:53:27 +02:00
jax authors
b7d7936e99
Merge pull request #8716 from gnecula:remat_translation_lax
...
PiperOrigin-RevId: 413351001
2021-12-01 01:46:36 -08:00
jax authors
4d9d6497ce
Merge pull request #8742 from jakevdp:sparse-empty
...
PiperOrigin-RevId: 413333035
2021-11-30 23:46:15 -08:00
George Necula
6288d67fef
Refactor the lowering of remat.
...
Replace the XLA-specific lowering with the same logic written
with LAX primitives. Then use this for HLO, MLIR, and jax2tf lowerings.
2021-12-01 09:40:37 +02:00
jax authors
88242bbff9
Merge pull request #8741 from jakevdp:dtype-none
...
PiperOrigin-RevId: 413331832
2021-11-30 23:39:12 -08:00
jax authors
6bb882cb56
Merge pull request #8722 from jakevdp:linalg-weak-types
...
PiperOrigin-RevId: 413317854
2021-11-30 21:44:55 -08:00
jax authors
800aac8fd3
Merge pull request #8681 from jakevdp:numpy-faq
...
PiperOrigin-RevId: 413316336
2021-11-30 21:33:37 -08:00
jax authors
3f2a4b10f6
Merge pull request #8650 from tlu7:update-qdwh-test
...
PiperOrigin-RevId: 413290729
2021-11-30 18:18:15 -08:00
Tianjian Lu
19554e21d3
Enable QDWH TPU tests.
2021-11-30 15:47:50 -08:00
Jake VanderPlas
47e88ded05
[x64] ensure scatter functionality preserves weak_type
2021-11-30 15:43:06 -08:00
Jake VanderPlas
0b872bb5d0
[sparse] add sparse.empty() utility to create empty matrix
2021-11-30 14:45:43 -08:00
Peter Hawkins
68e9e1c26d
Consolidate more XLA-lowering logic between jit, pmap, and xmap.
...
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.
PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Jake VanderPlas
03197dd298
[x64] improve consistency in handling dtype=None
2021-11-30 13:51:38 -08:00
jax authors
dd6a6f206c
Merge pull request #8339 from mariogeiger:main
...
PiperOrigin-RevId: 413227294
2021-11-30 13:15:03 -08:00
Peter Hawkins
34ec805698
[MLIR] Fix test failures on GPU and TPU.
...
PiperOrigin-RevId: 413226939
2021-11-30 13:11:01 -08:00
jax authors
737d021fd2
Merge pull request #8329 from AdrienCorenflos:patch-1
...
PiperOrigin-RevId: 413223979
2021-11-30 12:57:58 -08:00
jax authors
6bac78d500
Merge pull request #8735 from pschuh:update-docs
...
PiperOrigin-RevId: 413200537
2021-11-30 11:21:41 -08:00
Jake VanderPlas
022f8ac2ee
[x64] preserve weak types in jax.scipy.sparse solvers
2021-11-30 10:36:28 -08:00
Parker Schuh
46a1033311
Update device_get docs to mention parrallelism.
2021-11-30 10:20:11 -08:00
jax authors
aacee8f0c4
Merge pull request #8719 from jakevdp:pad-weak-type
...
PiperOrigin-RevId: 413166641
2021-11-30 09:18:48 -08:00
jax authors
8c2c054ab2
Merge pull request #8728 from LenaMartens:changelist/413128290
...
PiperOrigin-RevId: 413166628
2021-11-30 09:14:46 -08:00
Peter Hawkins
db3c3aae87
[JAX] Correctly propagate Python errors out of pytree code when the keys of an enum value cannot be sorted.
...
Also catch std::runtime_error since the pytree code may throw it.
PiperOrigin-RevId: 413160923
2021-11-30 08:50:30 -08:00
Peter Hawkins
fa411d864e
[MLIR] Fix CPU test failures for MLIR lowering.
...
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.
Quite a few primitives still use fallback paths.
PiperOrigin-RevId: 413130158
2021-11-30 06:08:55 -08:00
Lena Martens
cb6a3f216f
Leak checker: garbage collect before collecting hanging references.
2021-11-30 14:02:51 +00:00