1661 Commits

Author SHA1 Message Date
Yash Katariya
854b2c85db Drop into Auto mode for .at[...].set(...) but instead of taking an out_sharding argument in set, use the input array's sharding. Since this is an update, after .set, the input array's sharding should be preserved.
Fixes https://github.com/jax-ml/jax/issues/28111

PiperOrigin-RevId: 749089846
2025-04-18 11:17:38 -07:00
Yash Katariya
7de522c5a3 Enter into auto mode for .at[...].get(...) a bit earlier so that all ops inside _gather are in auto mode.
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.

PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Sergei Lebedev
c576d328bd Added lax.axis_size and switched all existing usage of psum(1, ...) to it
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Jake VanderPlas
42542feac6 jnp.power: better docs for invalid input 2025-04-14 10:42:29 -07:00
carlosgmartin
b6a46310d1 Merge tuple_replace and tuple_update in jax._src.util. 2025-04-09 12:50:42 -04:00
jax authors
76825a2d45 Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
PiperOrigin-RevId: 745216021
2025-04-08 11:09:58 -07:00
Jake VanderPlas
b7d430f96b jnp.repeat: don't cast repeats to array, as they must be static. 2025-04-08 10:32:03 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Jake VanderPlas
96e63eaee8 jnp.linalg: add symmetrize_input argument & docs 2025-04-07 14:46:38 -07:00
Jake VanderPlas
d3cfff057f jax.numpy: support __jax_array__ in remaining APIs 2025-04-07 14:08:35 -07:00
jax authors
35d75183c7 _attempt_rewriting_take_via_slice(): canonicalize the slice index before checking it's not too long, so that e.g. my_1d_array[:, ...] can be treated as a slice rather than generating a gather operation.
PiperOrigin-RevId: 743986126
2025-04-04 10:10:38 -07:00
jax authors
d7fc04b682 Merge pull request #27681 from jakevdp:jax-array
PiperOrigin-RevId: 743658654
2025-04-03 12:32:00 -07:00
Sergei Lebedev
9c58a112b3 jnp.array no longer accepts None
PiperOrigin-RevId: 743291099
2025-04-02 14:58:51 -07:00
Jake VanderPlas
96780f19b0 jax.numpy: support __jax_array__ in several more functions 2025-04-02 14:28:54 -07:00
jax authors
e75b66463c Merge pull request #27680 from jakevdp:array-api-version
PiperOrigin-RevId: 743277981
2025-04-02 14:22:03 -07:00
Jake VanderPlas
a2d62e2d3a [array_api] update array_api_version to 2024.12 2025-04-02 13:46:07 -07:00
Jake VanderPlas
7f4e8c56fe jnp.concat and friends: support __jax_array__ 2025-04-02 13:17:59 -07:00
Jake VanderPlas
3aeabaedea jnp.isinf & friends: support __jax_array__ 2025-04-02 08:38:54 -07:00
Sergei Lebedev
45d577d3dc Prepare for disallowing jnp.array(None)
PiperOrigin-RevId: 743074472
2025-04-02 04:17:36 -07:00
jax authors
398e8b0d79 Merge pull request #27644 from jakevdp:cumulative-jax-array
PiperOrigin-RevId: 743046047
2025-04-02 02:37:13 -07:00
Ayaka
f139192201 Add OOB checks to jax.numpy array indexing
PiperOrigin-RevId: 742927160
2025-04-01 19:17:57 -07:00
jax authors
747c5803c3 Merge pull request #27632 from LouisJustinTALLOT:patch-1
PiperOrigin-RevId: 742864570
2025-04-01 15:36:02 -07:00
Jake VanderPlas
4908b2f167 cumulative reductions: support __jax_array__ on inputs 2025-04-01 13:02:25 -07:00
Jake VanderPlas
7b04a79fbd jnp.einsum: add support for __jax_array__ 2025-04-01 12:26:26 -07:00
Jake VanderPlas
a34c462875 jnp.select: support __jax_array__ for inputs 2025-04-01 09:53:29 -07:00
Louis-Justin TALLOT
6adb728975
Clarify documentation of jnp.heaviside 2025-04-01 02:46:30 -04:00
Jake VanderPlas
4003e2d0ee jnp.power: support __jax_array__ on inputs 2025-03-31 16:50:04 -07:00
Jake VanderPlas
ca36047ac9 __jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_transpose 2025-03-31 15:14:47 -07:00
Jake VanderPlas
200f826398 [array api] return all devices in devices() 2025-03-31 08:50:39 -07:00
jax authors
ebd90e06fa Merge pull request #27585 from jakevdp:default-dtype-doc
PiperOrigin-RevId: 741691513
2025-03-28 17:23:16 -07:00
Jake VanderPlas
dafebd0d7f DOC: add documentation note about default dtypes 2025-03-28 15:20:58 -07:00
jax authors
6edc31ae1d Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
jax authors
679b6102e1 Merge pull request #27488 from jakevdp:array-capabilities
PiperOrigin-RevId: 741565179
2025-03-28 10:16:08 -07:00
Jake VanderPlas
431c2c0807 cleanup now that we depend on ml_dtypes>=0.5 2025-03-28 07:44:38 -07:00
Jake VanderPlas
667c4a0ee0 Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim 2025-03-26 15:27:25 -07:00
Ayaka
ce3941c635 Add division-by-zero checks to jax.numpy functions
PiperOrigin-RevId: 740906595
2025-03-26 14:35:56 -07:00
Jake VanderPlas
66908372af jnp.tri*_indices: support __jax_array__ inputs 2025-03-26 14:06:26 -07:00
Jake VanderPlas
096810a721 [array API] make capabilities more accurate 2025-03-26 12:11:47 -07:00
Ayaka
feed69c561 Add nan checking to jax.numpy functions
PiperOrigin-RevId: 740838221
2025-03-26 11:19:22 -07:00
Ayaka
b1b281a427 Prototype of adding error checking to jax.numpy functions
PiperOrigin-RevId: 740822504
2025-03-26 10:37:34 -07:00
Jake VanderPlas
91a07ea2e8 Clean up a number of finalized deprecations 2025-03-26 09:57:19 -07:00
Jake VanderPlas
85150471e2 Support __jax_array__ in jnp.full_like & co 2025-03-25 13:45:54 -07:00
Sergei Lebedev
92f231e875 Delay the unflattening in jnp.array
Reverts 53e8eac7134a13c1d28de673e7e3a23b4a837aed

PiperOrigin-RevId: 740012608
2025-03-24 11:32:23 -07:00
Brian Zhao
53e8eac713 Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1
PiperOrigin-RevId: 739258607
2025-03-21 12:12:45 -07:00
Sergei Lebedev
be57133095 Delay the unflattening in jnp.array
PiperOrigin-RevId: 739143346
2025-03-21 05:18:41 -07:00
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Peter Hawkins
6fa98fc0a4 Use "x is y" rather than "id(x) == id(y)".
The latter involves at least two object constructions.

PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
jax authors
18f2f19c1a Merge pull request #26525 from wenscarl:e2m1fn
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Yash Katariya
e9486920e8 Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None. So that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00