12788 Commits

Author SHA1 Message Date
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Sergei Lebedev
32922f61e9 jax.debug.callback now requires a Callable[..., None]
This makes the "return value is ignored" behavior explicit in the type.

PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
George Necula
cea36a0438 [jax2tf] Adjust tolerance for asinh test.
This test has started to fail in compiled mode, for complex128, but with small errors (1e-14).
Adjust the tolerance for both the native and non-native serialization mode.

PiperOrigin-RevId: 626373781
2024-04-19 08:31:53 -07:00
jax authors
c7517b832a Merge pull request #20825 from jakevdp:gammasgn
PiperOrigin-RevId: 626347182
2024-04-19 06:25:36 -07:00
Jake VanderPlas
41fa67c2dc Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Jake VanderPlas
568db105ea Add jax.scipy.special.gammasgn 2024-04-18 16:14:55 -07:00
jax authors
de7d3b6838 Merge pull request #20816 from superbobry:int4
PiperOrigin-RevId: 626131730
2024-04-18 13:26:10 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
9c9e805e82 [Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
The existing lowering path supports only while_loops which can be converted to fori_loop.
That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations.

This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while".

PiperOrigin-RevId: 626089180
2024-04-18 11:03:52 -07:00
jax authors
51763d8b5d Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
PiperOrigin-RevId: 626056825
2024-04-18 09:20:48 -07:00
Sergei Lebedev
a13efc2815 Added int4 and uint4 to dtype-specific tests
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
2024-04-18 15:20:20 +01:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
jax authors
1a8aae0e00 Merge pull request #20768 from chaserileyroberts:chase/eval_jaxpr_in_extend
PiperOrigin-RevId: 625830536
2024-04-17 15:48:31 -07:00
Chase Roberts
74c2e25314 Add more imports to jax extend 2024-04-17 15:13:17 -07:00
Yash Katariya
7cb0e601de Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX
PiperOrigin-RevId: 625763988
2024-04-17 12:08:49 -07:00
Adam Paszke
acde885b81 Fix Pallas' while_loop lowering to properly account for the loop length
The previous implementation was only valid when the lower bound was 0.

PiperOrigin-RevId: 625613195
2024-04-17 02:34:32 -07:00
jax authors
6963c77044 Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d
PiperOrigin-RevId: 625541309
2024-04-16 20:28:24 -07:00
jax authors
1c4195c55c Merge pull request #20787 from superbobry:lazy-imports
PiperOrigin-RevId: 625468662
2024-04-16 15:15:55 -07:00
Sergei Lebedev
1be5451179 Import rich lazily
This ensures that the timing of `import jax` is not affected by `rich` being
installed.

See also #20778.
2024-04-16 22:25:33 +01:00
Peter Hawkins
f18739900c Import etils.epath lazily.
Reduces jax import time.

PiperOrigin-RevId: 625452204
2024-04-16 14:23:05 -07:00
Yue Sheng
1f83908bae Temporarily disable async dispatch on JAX CPU by setting 'jax_cpu_enable_async_dispatch' to be False by default, as we observed abnormal memory usage increases.
PiperOrigin-RevId: 625448228
2024-04-16 14:10:17 -07:00
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
jax authors
adbb11f9fe Merge pull request #20778 from superbobry:lazy-imports
PiperOrigin-RevId: 625415993
2024-04-16 12:25:04 -07:00
jax authors
42398a5411 Merge pull request #20773 from jakevdp:update-array-api
PiperOrigin-RevId: 625394362
2024-04-16 11:17:45 -07:00
jax authors
d2a2b9f29e Merge pull request #20657 from Sai-Suraj-27:fix_error_messages
PiperOrigin-RevId: 625389526
2024-04-16 11:03:50 -07:00
Sergei Lebedev
35136851f2 Import web_pdb lazily
This ensures that `import jax` is not affected by `web_pdb` being installed.

Note that I also added an atexit handler closing the active `web_pdb`
consoles. This is strictly speaking not necessary, as the server powering
the console is running in a daemon thread, but nice-to-have anyway.
2024-04-16 17:22:11 +01:00
jax authors
06cd05d1d6 Merge pull request #20771 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 625311360
2024-04-16 06:36:52 -07:00
Jake VanderPlas
572c16284e [array api] update to latest test repo commit 2024-04-16 06:09:00 -07:00
rajasekharporeddy
660d612e6a DOC: Fix docstring typos in scipy special functions 2024-04-16 16:05:39 +05:30
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3 Async dispatch expensive computations on the JAX CPU backend.
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.

Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.

PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Yash Katariya
eb92a5c711 Add layout support to make_array_from_callback.
PiperOrigin-RevId: 625048520
2024-04-15 12:38:34 -07:00
jax authors
b9a853d0c1 Merge pull request #20734 from hawkinsp:atfork
PiperOrigin-RevId: 625038840
2024-04-15 12:06:20 -07:00
jax authors
5f22b12576 Merge pull request #20754 from Micky774:array-api-hypot
PiperOrigin-RevId: 625035601
2024-04-15 11:56:53 -07:00
jax authors
982ab5337d Merge pull request #20753 from Micky774:array-api-expose
PiperOrigin-RevId: 625027257
2024-04-15 11:31:44 -07:00
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
Meekail Zain
8b93da1830 Expose existing functions in array API namespace 2024-04-15 16:25:30 +00:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
jax authors
b7293005af Merge pull request #20762 from j-towns:scatter-doc-correction
PiperOrigin-RevId: 624971136
2024-04-15 08:38:57 -07:00
Junwhan Ahn
ac1a53d8e4 Optimize _create_copy_plan by iterating over only the shards that are needed for materialization
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.

The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
2024-04-15 08:29:47 -07:00
jax authors
3a09404426 Merge pull request #20586 from superbobry:jaxlib
PiperOrigin-RevId: 624941598
2024-04-15 06:40:07 -07:00
Jamie Townsend
b2783120c0 Correct a name in ScatterDimensionNumbers docstring 2024-04-15 10:36:24 +00:00
Yash Katariya
2c85ca6fec If callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
2024-04-14 14:35:57 -07:00
Yash Katariya
5ce7dca969 Add support for loading checkpoints with a given layout to the array serialization library
PiperOrigin-RevId: 624596358
2024-04-13 19:35:50 -07:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Yash Katariya
70dca30395 Remove the dead code now that jax.Array is the only array we have
PiperOrigin-RevId: 624390245
2024-04-12 21:41:42 -07:00
Yash Katariya
9e989321f1 Make sure we don't return GSPMDSharding in compiled.input_shardings
PiperOrigin-RevId: 624343180
2024-04-12 17:52:44 -07:00
Roy Frostig
09415607bb fix up extend:core build rule
We want `pytype_strict_library` here.

PiperOrigin-RevId: 624337356
2024-04-12 17:31:10 -07:00