12969 Commits

Author SHA1 Message Date
Yash Katariya
8239674dab Replace donation_vector's logic with donation_vector_with_in_tree which is now deleted
PiperOrigin-RevId: 627556267
2024-04-23 17:38:30 -07:00
jax authors
8842c0bc91 Merge pull request #20901 from carlosgmartin:lax-scan-xs-none
PiperOrigin-RevId: 627546834
2024-04-23 17:01:24 -07:00
Yash Katariya
3f17626f4b Fix donation with kwargs. The problem is that pytrees sort dictionaries by default. So if we create the donation vector with original kwargs order, it won't match the aval order (which is created by sorting kwargs i.e. dict) and we end up donating the wrong input.
Fix this by calculating the donation vector by looking at the in_tree.

A bonus is that we can now cache the calculation of donation vector leading to faster tracing times in JAX.

PiperOrigin-RevId: 627512710
2024-04-23 14:50:04 -07:00
Paul Wohlhart
6b85557cc1 Use xla_client.Device in jax.numpy.
PiperOrigin-RevId: 627507470
2024-04-23 14:32:08 -07:00
carlosgmartin
2b332de9d7 Let xs=None by default in lax.scan. 2024-04-23 17:26:23 -04:00
Enrique Piqueras
cf9c08589e Add builtin cc dataclass pytree node for performance.
PiperOrigin-RevId: 627502102
2024-04-23 14:14:49 -07:00
jax authors
8b1418244b Merge pull request #20885 from rajasekharporeddy:test_branch4
PiperOrigin-RevId: 627486343
2024-04-23 13:29:40 -07:00
rajasekharporeddy
c536eea1e5 Fix jax.scipy.stats.beta.logpdf to emulate scipy.stats.beta.logpdf 2024-04-24 01:24:09 +05:30
jax authors
ba57ce3bd1 Merge pull request #20891 from rajasekharporeddy:test_branch1
PiperOrigin-RevId: 627472771
2024-04-23 12:39:18 -07:00
Peter Hawkins
ab30bcf071 [jax2tf] Bump asinh test tolerance in graph and eager modes.
Fixes CI test failure due to LLVM update.

PiperOrigin-RevId: 627462404
2024-04-23 12:03:00 -07:00
rajasekharporeddy
95ed0538fd Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k 2024-04-24 00:29:52 +05:30
Pearu Peterson
e8ff7028f4 Workaround mpmath 1.3 issues in asin and asinh evaluation at infinities and on branch cuts. 2024-04-23 21:01:43 +03:00
Yunlong Liu
2df6b35dce Adds meaningful function names for better debugging.
The default `fn.__name__` was added in `_one_to_one_unop` but not other functions so that it leads to many downstream function wrappers giving unmeaningful names while debugging. For instance,

When a JAX numpy primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` will give,

```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: fn
```
instead of
```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: add
```
PiperOrigin-RevId: 627417452
2024-04-23 09:45:57 -07:00
jax authors
493698e6e0 Merge pull request #20195 from Micky774:array_api_astype
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070 Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
jax authors
47c9495dcb Return early from eigh for small matrices. This was accidentally removed cl/577222219.
PiperOrigin-RevId: 627168970
2024-04-22 15:10:38 -07:00
jax authors
1f4c31d0af Merge pull request #20849 from mattjj:jit-docstring-tweaks
PiperOrigin-RevId: 627167535
2024-04-22 15:05:38 -07:00
Parker Schuh
7ba811eb4a Support auto in shard_map.
- Pull mesh from NamedSharding when rewriting manual axes.
- Properly set manual axes in SPMDAxisContext in shard_map.
- Properly set dims as unspecified inside shard_map.

PiperOrigin-RevId: 627156892
2024-04-22 14:29:35 -07:00
Selam Waktola
b02f82b815 redundant phrase 'ever time' removed 2024-04-22 11:47:23 -07:00
Junwhan Ahn
4be25d7151 Optimize jax.device_put() dispatch for 1:1 device-to-device transfers
* Cache the sharding index comparison in addition to sharding index calculation. This helps when the list of indices is expensive to compare.
* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.
* Use `a is b` instead of `id(a) == id(b)` since the former is more concise.

PiperOrigin-RevId: 627080325
2024-04-22 10:24:35 -07:00
Meekail Zain
30cd3b88fd Add support for copy kwarg in astype to match Array API 2024-04-22 16:25:37 +00:00
jax authors
667a0c1fe5 Add some docstrings for remote DMAs and semaphore barriers.
PiperOrigin-RevId: 627037991
2024-04-22 08:01:57 -07:00
Clemens Giuliani
1e32fb510b ignore type 2024-04-22 16:35:31 +02:00
Adam Paszke
b79f3b77ef [Mosaic:GPU] Update lowering to match upstream changes in the LLVM dialect
LLVM integer arithmetic ops now explicitly require the overflow flags.

PiperOrigin-RevId: 627020143
2024-04-22 06:46:27 -07:00
Jake VanderPlas
5953ae896c jnp.select: lower to lax.select_n 2024-04-22 05:39:11 -07:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Yash Katariya
1837b436d7 Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
2024-04-19 20:59:55 -07:00
Matthew Johnson
b8df23c25b tweak jit docstring 2024-04-19 17:37:52 -07:00
Sergei Lebedev
5981df7bab Removed unnecessary jax.tree.map calls from *_callback_impl functions
jax.device_put works for any PyTree.

PiperOrigin-RevId: 626510762
2024-04-19 17:34:05 -07:00
Clemens Giuliani
fdd24d137f try remove MpiCollectives from type annotation 2024-04-20 00:53:31 +02:00
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb Add a common flag for the collectives implementations on cpu. 2024-04-19 20:55:35 +02:00
Sergei Lebedev
32922f61e9 jax.debug.callback now requires a Callable[..., None]
This makes the "return value is ignored" behavior explicit in the type.

PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
George Necula
cea36a0438 [jax2tf] Adjust tolerance for asinh test.
This test has started to fail in compiled mode, for complex128, but with small errors (1e-14).
Adjust the tolerance for both the native and non-native serialization mode.

PiperOrigin-RevId: 626373781
2024-04-19 08:31:53 -07:00
jax authors
c7517b832a Merge pull request #20825 from jakevdp:gammasgn
PiperOrigin-RevId: 626347182
2024-04-19 06:25:36 -07:00
Jake VanderPlas
41fa67c2dc Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Jake VanderPlas
568db105ea Add jax.scipy.special.gammasgn 2024-04-18 16:14:55 -07:00
jax authors
de7d3b6838 Merge pull request #20816 from superbobry:int4
PiperOrigin-RevId: 626131730
2024-04-18 13:26:10 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
9c9e805e82 [Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
The existing lowering path supports only while_loops which can be converted to fori_loop.
That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations.

This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while".

PiperOrigin-RevId: 626089180
2024-04-18 11:03:52 -07:00
jax authors
51763d8b5d Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
PiperOrigin-RevId: 626056825
2024-04-18 09:20:48 -07:00
Sergei Lebedev
a13efc2815 Added int4 and uint4 to dtype-specific tests
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
2024-04-18 15:20:20 +01:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
jax authors
1a8aae0e00 Merge pull request #20768 from chaserileyroberts:chase/eval_jaxpr_in_extend
PiperOrigin-RevId: 625830536
2024-04-17 15:48:31 -07:00
Chase Roberts
74c2e25314 Add more imports to jax extend 2024-04-17 15:13:17 -07:00
Yash Katariya
7cb0e601de Remove the spmd_mode check from OSS JAX since enhanced barrier is switched on for OSS JAX
PiperOrigin-RevId: 625763988
2024-04-17 12:08:49 -07:00
Adam Paszke
acde885b81 Fix Pallas' while_loop lowering to properly account for the loop length
The previous implementation was only valid when the lower bound was 0.

PiperOrigin-RevId: 625613195
2024-04-17 02:34:32 -07:00
jax authors
6963c77044 Reverts f18739900c615a85e8d182bcf3217f704cf7aa0d
PiperOrigin-RevId: 625541309
2024-04-16 20:28:24 -07:00