Sergei Lebedev
6e23c14f85
jax.debug.callback now passes arguments as jax.Arrays
...
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes #20627 .
PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Sergei Lebedev
32922f61e9
jax.debug.callback now requires a Callable[..., None]
...
This makes the "return value is ignored" behavior explicit in the type.
PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
George Necula
cea36a0438
[jax2tf] Adjust tolerance for asinh test.
...
This test has started to fail in compiled mode, for complex128, but with small errors (1e-14).
Adjust the tolerance for both the native and non-native serialization mode.
PiperOrigin-RevId: 626373781
2024-04-19 08:31:53 -07:00
jax authors
c7517b832a
Merge pull request #20825 from jakevdp:gammasgn
...
PiperOrigin-RevId: 626347182
2024-04-19 06:25:36 -07:00
Jake VanderPlas
41fa67c2dc
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
...
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yash Katariya
837f0bbf6f
Cache the _check_sharding
check in device_put. If aval and sharding are the same, no need to check multiple times
...
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Jake VanderPlas
568db105ea
Add jax.scipy.special.gammasgn
2024-04-18 16:14:55 -07:00
jax authors
de7d3b6838
Merge pull request #20816 from superbobry:int4
...
PiperOrigin-RevId: 626131730
2024-04-18 13:26:10 -07:00
Yue Sheng
c2d4373535
Make core.Token
a non-trivial class which wraps a jax.Array
. Currently, we use a singleton and empty core.token
object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
...
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).
PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
9c9e805e82
[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
...
The existing lowering path supports only while_loops which can be converted to fori_loop.
That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations.
This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while".
PiperOrigin-RevId: 626089180
2024-04-18 11:03:52 -07:00
jax authors
51763d8b5d
Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
...
PiperOrigin-RevId: 626056825
2024-04-18 09:20:48 -07:00
Sergei Lebedev
a13efc2815
Added int4 and uint4 to dtype-specific tests
...
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
2024-04-18 15:20:20 +01:00
jax authors
bb8cf34a31
Document the fact that jax.clear_caches() doesn't affect the persistent cache.
...
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Adam Paszke
8e3f5b1018
Initial commit for Mosaic GPU
...
Moving this to JAX to make it easier to explore Pallas integration.
PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
jax authors
1a8aae0e00
Merge pull request #20768 from chaserileyroberts:chase/eval_jaxpr_in_extend
...
PiperOrigin-RevId: 625830536
2024-04-17 15:48:31 -07:00
Chase Roberts
74c2e25314
Add more imports to jax extend
2024-04-17 15:13:17 -07:00
Yash Katariya
7cb0e601de
Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX
...
PiperOrigin-RevId: 625763988
2024-04-17 12:08:49 -07:00
Adam Paszke
acde885b81
Fix Pallas' while_loop lowering to properly account for the loop length
...
The previous implementation was only valid when the lower bound was 0.
PiperOrigin-RevId: 625613195
2024-04-17 02:34:32 -07:00
jax authors
6963c77044
Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d
...
PiperOrigin-RevId: 625541309
2024-04-16 20:28:24 -07:00
jax authors
1c4195c55c
Merge pull request #20787 from superbobry:lazy-imports
...
PiperOrigin-RevId: 625468662
2024-04-16 15:15:55 -07:00
Sergei Lebedev
1be5451179
Import rich lazily
...
This ensures that the timing of `import jax` is not affected by `rich` being
installed.
See also #20778 .
2024-04-16 22:25:33 +01:00
Peter Hawkins
f18739900c
Import etils.epath lazily.
...
Reduces jax import time.
PiperOrigin-RevId: 625452204
2024-04-16 14:23:05 -07:00
Yue Sheng
1f83908bae
Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be False
by default, as we observed abnormal memory usage increases.
...
PiperOrigin-RevId: 625448228
2024-04-16 14:10:17 -07:00
Meekail Zain
ceeb975735
Add new cumulative_sum function to numpy and array_api
2024-04-16 19:57:55 +00:00
jax authors
adbb11f9fe
Merge pull request #20778 from superbobry:lazy-imports
...
PiperOrigin-RevId: 625415993
2024-04-16 12:25:04 -07:00
jax authors
42398a5411
Merge pull request #20773 from jakevdp:update-array-api
...
PiperOrigin-RevId: 625394362
2024-04-16 11:17:45 -07:00
jax authors
d2a2b9f29e
Merge pull request #20657 from Sai-Suraj-27:fix_error_messages
...
PiperOrigin-RevId: 625389526
2024-04-16 11:03:50 -07:00
Sergei Lebedev
35136851f2
Import web_pdb lazily
...
This ensures that `import jax` is not affected by `web_pdb` being installed.
Note that I also added an atexit handler closing the active `web_pdb`
consoles. This is strictly speaking not necessary, as the server powering
the console is running in a daemon thread, but nice-to-have anyway.
2024-04-16 17:22:11 +01:00
jax authors
06cd05d1d6
Merge pull request #20771 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 625311360
2024-04-16 06:36:52 -07:00
Jake VanderPlas
572c16284e
[array api] update to latest test repo commit
2024-04-16 06:09:00 -07:00
rajasekharporeddy
660d612e6a
DOC: Fix docstring typos in scipy special functions
2024-04-16 16:05:39 +05:30
Meekail Zain
6bdc83c680
Add new unstack function to numpy/array_api namespaces
2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3
Async dispatch expensive computations on the JAX CPU backend.
...
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.
Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.
PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Yash Katariya
eb92a5c711
Add layout support to make_array_from_callback
.
...
PiperOrigin-RevId: 625048520
2024-04-15 12:38:34 -07:00
jax authors
b9a853d0c1
Merge pull request #20734 from hawkinsp:atfork
...
PiperOrigin-RevId: 625038840
2024-04-15 12:06:20 -07:00
jax authors
5f22b12576
Merge pull request #20754 from Micky774:array-api-hypot
...
PiperOrigin-RevId: 625035601
2024-04-15 11:56:53 -07:00
jax authors
982ab5337d
Merge pull request #20753 from Micky774:array-api-expose
...
PiperOrigin-RevId: 625027257
2024-04-15 11:31:44 -07:00
Meekail Zain
2899213efb
Fixed hypot bug on nan/inf pairings, began deprecation of non-real values
2024-04-15 17:56:16 +00:00
Meekail Zain
8b93da1830
Expose existing functions in array API namespace
2024-04-15 16:25:30 +00:00
Yash Katariya
90401d51e9
Accept layout on ShapeDtypeStruct
on the sharding
argument. DeviceLocalLayout.AUTO
is not allowed on SDS.
...
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
jax authors
b7293005af
Merge pull request #20762 from j-towns:scatter-doc-correction
...
PiperOrigin-RevId: 624971136
2024-04-15 08:38:57 -07:00
Junwhan Ahn
ac1a53d8e4
Optimize _create_copy_plan
by iterating over only the shards that are needed for materialization
...
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.
The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.
PiperOrigin-RevId: 624969222
2024-04-15 08:29:47 -07:00
jax authors
3a09404426
Merge pull request #20586 from superbobry:jaxlib
...
PiperOrigin-RevId: 624941598
2024-04-15 06:40:07 -07:00
Jamie Townsend
b2783120c0
Correct a name in ScatterDimensionNumbers docstring
2024-04-15 10:36:24 +00:00
Yash Katariya
2c85ca6fec
If callback returns a fully replicated global array, return it as is.
...
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.
PiperOrigin-RevId: 624763603
2024-04-14 14:35:57 -07:00
Yash Katariya
5ce7dca969
Add support for loading checkpoints with a given layout to the array serialization library
...
PiperOrigin-RevId: 624596358
2024-04-13 19:35:50 -07:00
Sergei Lebedev
754fab91f7
Bumped the minimum jaxlib to 0.4.23
...
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Yash Katariya
70dca30395
Remove the dead code now that jax.Array is the only array we have
...
PiperOrigin-RevId: 624390245
2024-04-12 21:41:42 -07:00
Yash Katariya
9e989321f1
Make sure we don't return GSPMDSharding in compiled.input_shardings
...
PiperOrigin-RevId: 624343180
2024-04-12 17:52:44 -07:00
Roy Frostig
09415607bb
fix up extend:core
build rule
...
We want `pytype_strict_library` here.
PiperOrigin-RevId: 624337356
2024-04-12 17:31:10 -07:00