Yash Katariya
854b2c85db
Drop into Auto
mode for .at[...].set(...)
but instead of taking an out_sharding
argument in set
, use the input array's sharding
. Since this is an update, after .set
, the input array's sharding should be preserved.
...
Fixes https://github.com/jax-ml/jax/issues/28111
PiperOrigin-RevId: 749089846
2025-04-18 11:17:38 -07:00
Yash Katariya
7de522c5a3
Enter into auto mode for .at[...].get(...)
a bit earlier so that all ops inside _gather
are in auto mode.
...
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.
PiperOrigin-RevId: 748867490
2025-04-17 17:42:13 -07:00
Sergei Lebedev
c576d328bd
Added lax.axis_size
and switched all existing usage of psum(1, ...)
to it
...
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Jake VanderPlas
42542feac6
jnp.power: better docs for invalid input
2025-04-14 10:42:29 -07:00
carlosgmartin
b6a46310d1
Merge tuple_replace and tuple_update in jax._src.util.
2025-04-09 12:50:42 -04:00
jax authors
76825a2d45
Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
...
PiperOrigin-RevId: 745216021
2025-04-08 11:09:58 -07:00
Jake VanderPlas
b7d430f96b
jnp.repeat: don't cast repeats to array, as they must be static.
2025-04-08 10:32:03 -07:00
Peter Hawkins
e02faabfb2
Replace references to jax.readthedocs.io with docs.jax.dev.
...
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Jake VanderPlas
96e63eaee8
jnp.linalg: add symmetrize_input argument & docs
2025-04-07 14:46:38 -07:00
Jake VanderPlas
d3cfff057f
jax.numpy: support __jax_array__ in remaining APIs
2025-04-07 14:08:35 -07:00
jax authors
35d75183c7
_attempt_rewriting_take_via_slice()
: canonicalize the slice index before checking it's not too long, so that e.g. my_1d_array[:, ...]
can be treated as a slice rather than generating a gather operation.
...
PiperOrigin-RevId: 743986126
2025-04-04 10:10:38 -07:00
jax authors
d7fc04b682
Merge pull request #27681 from jakevdp:jax-array
...
PiperOrigin-RevId: 743658654
2025-04-03 12:32:00 -07:00
Sergei Lebedev
9c58a112b3
jnp.array
no longer accepts None
...
PiperOrigin-RevId: 743291099
2025-04-02 14:58:51 -07:00
Jake VanderPlas
96780f19b0
jax.numpy: support __jax_array__ in several more functions
2025-04-02 14:28:54 -07:00
jax authors
e75b66463c
Merge pull request #27680 from jakevdp:array-api-version
...
PiperOrigin-RevId: 743277981
2025-04-02 14:22:03 -07:00
Jake VanderPlas
a2d62e2d3a
[array_api] update array_api_version to 2024.12
2025-04-02 13:46:07 -07:00
Jake VanderPlas
7f4e8c56fe
jnp.concat and friends: support __jax_array__
2025-04-02 13:17:59 -07:00
Jake VanderPlas
3aeabaedea
jnp.isinf & friends: support __jax_array__
2025-04-02 08:38:54 -07:00
Sergei Lebedev
45d577d3dc
Prepare for disallowing jnp.array(None)
...
PiperOrigin-RevId: 743074472
2025-04-02 04:17:36 -07:00
jax authors
398e8b0d79
Merge pull request #27644 from jakevdp:cumulative-jax-array
...
PiperOrigin-RevId: 743046047
2025-04-02 02:37:13 -07:00
Ayaka
f139192201
Add OOB checks to jax.numpy array indexing
...
PiperOrigin-RevId: 742927160
2025-04-01 19:17:57 -07:00
jax authors
747c5803c3
Merge pull request #27632 from LouisJustinTALLOT:patch-1
...
PiperOrigin-RevId: 742864570
2025-04-01 15:36:02 -07:00
Jake VanderPlas
4908b2f167
cumulative reductions: support __jax_array__ on inputs
2025-04-01 13:02:25 -07:00
Jake VanderPlas
7b04a79fbd
jnp.einsum: add support for __jax_array__
2025-04-01 12:26:26 -07:00
Jake VanderPlas
a34c462875
jnp.select: support __jax_array__ for inputs
2025-04-01 09:53:29 -07:00
Louis-Justin TALLOT
6adb728975
Clarify documentation of jnp.heaviside
2025-04-01 02:46:30 -04:00
Jake VanderPlas
4003e2d0ee
jnp.power: support __jax_array__ on inputs
2025-03-31 16:50:04 -07:00
Jake VanderPlas
ca36047ac9
__jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_transpose
2025-03-31 15:14:47 -07:00
Jake VanderPlas
200f826398
[array api] return all devices in devices()
2025-03-31 08:50:39 -07:00
jax authors
ebd90e06fa
Merge pull request #27585 from jakevdp:default-dtype-doc
...
PiperOrigin-RevId: 741691513
2025-03-28 17:23:16 -07:00
Jake VanderPlas
dafebd0d7f
DOC: add documentation note about default dtypes
2025-03-28 15:20:58 -07:00
jax authors
6edc31ae1d
Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
...
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
jax authors
679b6102e1
Merge pull request #27488 from jakevdp:array-capabilities
...
PiperOrigin-RevId: 741565179
2025-03-28 10:16:08 -07:00
Jake VanderPlas
431c2c0807
cleanup now that we depend on ml_dtypes>=0.5
2025-03-28 07:44:38 -07:00
Jake VanderPlas
667c4a0ee0
Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim
2025-03-26 15:27:25 -07:00
Ayaka
ce3941c635
Add division-by-zero checks to jax.numpy functions
...
PiperOrigin-RevId: 740906595
2025-03-26 14:35:56 -07:00
Jake VanderPlas
66908372af
jnp.tri*_indices: support __jax_array__ inputs
2025-03-26 14:06:26 -07:00
Jake VanderPlas
096810a721
[array API] make capabilities more accurate
2025-03-26 12:11:47 -07:00
Ayaka
feed69c561
Add nan checking to jax.numpy functions
...
PiperOrigin-RevId: 740838221
2025-03-26 11:19:22 -07:00
Ayaka
b1b281a427
Prototype of adding error checking to jax.numpy functions
...
PiperOrigin-RevId: 740822504
2025-03-26 10:37:34 -07:00
Jake VanderPlas
91a07ea2e8
Clean up a number of finalized deprecations
2025-03-26 09:57:19 -07:00
Jake VanderPlas
85150471e2
Support __jax_array__ in jnp.full_like & co
2025-03-25 13:45:54 -07:00
Sergei Lebedev
92f231e875
Delay the unflattening in jnp.array
...
Reverts 53e8eac7134a13c1d28de673e7e3a23b4a837aed
PiperOrigin-RevId: 740012608
2025-03-24 11:32:23 -07:00
Brian Zhao
53e8eac713
Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1
...
PiperOrigin-RevId: 739258607
2025-03-21 12:12:45 -07:00
Sergei Lebedev
be57133095
Delay the unflattening in jnp.array
...
PiperOrigin-RevId: 739143346
2025-03-21 05:18:41 -07:00
Peter Hawkins
3f91b4b43a
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
...
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/
Cleanup only, no functional changes intended.
PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Peter Hawkins
6fa98fc0a4
Use "x is y" rather than "id(x) == id(y)".
...
The latter involves at least two object constructions.
PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
jax authors
18f2f19c1a
Merge pull request #26525 from wenscarl:e2m1fn
...
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Yash Katariya
e9486920e8
Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None
. So that for a 2D input, P('data') continues to work.
...
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
shuw
c099e8081d
support e2m1fn
2025-03-05 17:44:34 +00:00