Sergei Lebedev
6e23c14f85
jax.debug.callback now passes arguments as jax.Arrays
...
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes #20627 .
PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Sergei Lebedev
32922f61e9
jax.debug.callback now requires a Callable[..., None]
...
This makes the "return value is ignored" behavior explicit in the type.
PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
Jieying Luo
b2375fa7e9
[PJRT C API] Add stream extension to support DLPack and implement this extension in CUDA plugin.
...
PiperOrigin-RevId: 626408630
2024-04-19 10:41:55 -07:00
George Necula
cea36a0438
[jax2tf] Adjust tolerance for asinh test.
...
This test has started to fail in compiled mode, for complex128, but with small errors (1e-14).
Adjust the tolerance for both the native and non-native serialization mode.
PiperOrigin-RevId: 626373781
2024-04-19 08:31:53 -07:00
jax authors
c7517b832a
Merge pull request #20825 from jakevdp:gammasgn
...
PiperOrigin-RevId: 626347182
2024-04-19 06:25:36 -07:00
Jake VanderPlas
41fa67c2dc
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
...
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yash Katariya
837f0bbf6f
Cache the _check_sharding
check in device_put. If aval and sharding are the same, no need to check multiple times
...
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
jax authors
8fec8a600f
Update XLA dependency to use revision
...
628a8c9490
.
PiperOrigin-RevId: 626235955
2024-04-18 20:33:07 -07:00
Jake VanderPlas
568db105ea
Add jax.scipy.special.gammasgn
2024-04-18 16:14:55 -07:00
jax authors
4d6c4d63d5
Merge pull request #20823 from sagelywizard:patch-1
...
PiperOrigin-RevId: 626175454
2024-04-18 15:59:07 -07:00
Benjamin Bastian
afc87785a6
Remove outdated section in cloud_tpu_colabs README
...
Colab now has TPU VMs, so this section is now out of date.
2024-04-18 15:48:13 -07:00
Marvin Kim
90e9e47a55
[Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
...
For configs that cannot be launched, we should not launch them via benchmark.
PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
jax authors
de7d3b6838
Merge pull request #20816 from superbobry:int4
...
PiperOrigin-RevId: 626131730
2024-04-18 13:26:10 -07:00
Yue Sheng
c2d4373535
Make core.Token
a non-trivial class which wraps a jax.Array
. Currently, we use a singleton and empty core.token
object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
...
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).
PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
9c9e805e82
[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
...
The existing lowering path supports only while_loops which can be converted to fori_loop.
That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations.
This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while".
PiperOrigin-RevId: 626089180
2024-04-18 11:03:52 -07:00
jax authors
6ca69f3824
Merge pull request #20537 from jakevdp:update-tpu-readme
...
PiperOrigin-RevId: 626065975
2024-04-18 09:51:46 -07:00
Yash Katariya
aaac1d8f46
Skip test_spmd_preserves_input_sharding_vmap_grad
unless xla_extension_version >= 258
...
PiperOrigin-RevId: 626062015
2024-04-18 09:37:57 -07:00
jax authors
51763d8b5d
Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
...
PiperOrigin-RevId: 626056825
2024-04-18 09:20:48 -07:00
Sergei Lebedev
a13efc2815
Added int4 and uint4 to dtype-specific tests
...
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
2024-04-18 15:20:20 +01:00
jax authors
bb8cf34a31
Document the fact that jax.clear_caches() doesn't affect the persistent cache.
...
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Adam Paszke
fa66e731e6
Increase sharding to avoid timeouts
...
PiperOrigin-RevId: 626008096
2024-04-18 06:04:41 -07:00
Adam Paszke
8e3f5b1018
Initial commit for Mosaic GPU
...
Moving this to JAX to make it easier to explore Pallas integration.
PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
jax authors
c4dea624cc
Update XLA dependency to use revision
...
fe9679250b
.
PiperOrigin-RevId: 625887588
2024-04-17 19:59:57 -07:00
jax authors
2a65cd38d7
Merge pull request #20806 from jakevdp:tweak-docs
...
PiperOrigin-RevId: 625869820
2024-04-17 18:22:42 -07:00
Jake VanderPlas
48e8457b81
DOC: one last readthrough of the new 101 tutorials
2024-04-17 16:08:38 -07:00
jax authors
1a8aae0e00
Merge pull request #20768 from chaserileyroberts:chase/eval_jaxpr_in_extend
...
PiperOrigin-RevId: 625830536
2024-04-17 15:48:31 -07:00
Chase Roberts
74c2e25314
Add more imports to jax extend
2024-04-17 15:13:17 -07:00
Jevin Jiang
d44b16cfde
[XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
...
PiperOrigin-RevId: 625816937
2024-04-17 15:01:37 -07:00
Yash Katariya
7cb0e601de
Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX
...
PiperOrigin-RevId: 625763988
2024-04-17 12:08:49 -07:00
jax authors
1c8534e034
Merge pull request #20802 from jakevdp:dup-thinking-in-jax
...
PiperOrigin-RevId: 625738710
2024-04-17 10:55:11 -07:00
jax authors
73bd54f0e0
Merge pull request #20800 from jakevdp:pin-sphinx
...
PiperOrigin-RevId: 625730488
2024-04-17 10:31:23 -07:00
Jake VanderPlas
edfdc36b4c
DOC: remove copy of thinking-in-jax from new tutorial flow
2024-04-17 10:27:15 -07:00
Jake VanderPlas
29b15e2e2d
DOC: pin sphinx to >=7.3.2
2024-04-17 09:51:27 -07:00
Adam Paszke
acde885b81
Fix Pallas' while_loop lowering to properly account for the loop length
...
The previous implementation was only valid when the lower bound was 0.
PiperOrigin-RevId: 625613195
2024-04-17 02:34:32 -07:00
jax authors
6963c77044
Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d
...
PiperOrigin-RevId: 625541309
2024-04-16 20:28:24 -07:00
jax authors
00e3a1eab9
Update XLA dependency to use revision
...
ca98bc910f
.
PiperOrigin-RevId: 625530388
2024-04-16 19:34:16 -07:00
jax authors
1d615a3bbd
Merge pull request #20792 from jakevdp:fix-sharding-doc
...
PiperOrigin-RevId: 625515766
2024-04-16 18:32:02 -07:00
jax authors
fc654a5772
Merge pull request #20732 from jakevdp:doc-stateful
...
PiperOrigin-RevId: 625515752
2024-04-16 18:27:26 -07:00
jax authors
07063fb11f
Merge pull request #20794 from jakevdp:pin-sphinx
...
PiperOrigin-RevId: 625515741
2024-04-16 18:27:05 -07:00
jax authors
e554416d2c
Merge pull request #20784 from jakevdp:doc-installation
...
PiperOrigin-RevId: 625515725
2024-04-16 18:22:29 -07:00
Jake VanderPlas
3c18a021c3
DOC: add stateful computation doc
2024-04-16 17:54:55 -07:00
Jake VanderPlas
40ad5b283f
Pin sphinx version to avoid error in 7.3.0
2024-04-16 16:51:06 -07:00
Jake VanderPlas
2552c29028
DOC: fix Mesh construction in sharding doc
2024-04-16 16:26:18 -07:00
jax authors
1c4195c55c
Merge pull request #20787 from superbobry:lazy-imports
...
PiperOrigin-RevId: 625468662
2024-04-16 15:15:55 -07:00
Sergei Lebedev
1be5451179
Import rich lazily
...
This ensures that the timing of `import jax` is not affected by `rich` being
installed.
See also #20778 .
2024-04-16 22:25:33 +01:00
Peter Hawkins
f18739900c
Import etils.epath lazily.
...
Reduces jax import time.
PiperOrigin-RevId: 625452204
2024-04-16 14:23:05 -07:00
Yue Sheng
1f83908bae
Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be False
by default, as we observed abnormal memory usage increases.
...
PiperOrigin-RevId: 625448228
2024-04-16 14:10:17 -07:00
jax authors
47815c54d0
Merge pull request #20756 from Micky774:array-api-cumulative-sum
...
PiperOrigin-RevId: 625439530
2024-04-16 13:43:47 -07:00
Meekail Zain
ceeb975735
Add new cumulative_sum function to numpy and array_api
2024-04-16 19:57:55 +00:00
Jake VanderPlas
a4ccb354f8
DOC: update installation guide
2024-04-16 12:46:36 -07:00