jax authors
b2a8df7183
Add the method
argument to jax.numpy.isin
stub.
...
This parameter is available from https://github.com/google/jax/pull/23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html .
PiperOrigin-RevId: 746606206
2025-04-11 15:15:22 -07:00
Jake VanderPlas
431c2c0807
cleanup now that we depend on ml_dtypes>=0.5
2025-03-28 07:44:38 -07:00
Jake VanderPlas
667c4a0ee0
Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim
2025-03-26 15:27:25 -07:00
Jake VanderPlas
66908372af
jnp.tri*_indices: support __jax_array__ inputs
2025-03-26 14:06:26 -07:00
Jake VanderPlas
91a07ea2e8
Clean up a number of finalized deprecations
2025-03-26 09:57:19 -07:00
shuw
c099e8081d
support e2m1fn
2025-03-05 17:44:34 +00:00
Jake VanderPlas
8cec6e636a
jax.numpy ndim/shape/size: deprecate non-array input
2025-03-04 10:42:32 -08:00
jax authors
72f0a90ee6
Merge pull request #26401 from jakevdp:numpy-consts
...
PiperOrigin-RevId: 728292846
2025-02-18 11:32:25 -08:00
BaconBreaker
422b747dfe
Change type signature of lexsort in stub file to match type signature in sorting.py
2025-02-18 13:24:02 +01:00
Jake VanderPlas
33b989ac9e
refactor: import numpy objects directly in jax.numpy
2025-02-14 12:47:58 -08:00
Jake VanderPlas
f750d0b855
refactor: move lax_numpy indexing routines to their own submodule
2025-02-13 12:03:07 -08:00
Jake VanderPlas
7ab7b214ac
refactor: move jnp.einsum impl into its own submodule
2025-02-12 09:05:30 -08:00
Jake VanderPlas
e6fc7f3e87
refactor: move lax_numpy tensor contractions into their own file
2025-02-10 18:56:18 -08:00
Jake VanderPlas
17215177fa
refactor: move lax_numpy window functions into their own file
2025-02-07 11:21:38 -08:00
jax authors
ec477634f1
Merge pull request #26376 from jakevdp:array-creation
...
PiperOrigin-RevId: 724399604
2025-02-07 10:48:05 -08:00
Jake VanderPlas
d3b3cd369f
refactor: move sorting ops out of lax_numpy
2025-02-07 08:18:04 -08:00
Jake VanderPlas
7bacfbc658
refactor: move array creation routines out of lax_numpy.py
2025-02-06 15:47:30 -08:00
Jake VanderPlas
b4f98eef7e
refactor: move scalar type defs out of lax_numpy.py
2025-02-06 14:48:10 -08:00
wenscarl
638c6ae046
Add e8m0fnu support by conditional dtype.
2025-01-22 21:57:43 +00:00
Yash Katariya
97cd748376
Rename out_type -> out_sharding parameter on einsum
...
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
c6b5ac5c7b
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
...
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
Jake VanderPlas
c73f306099
Finalize deprecation of jnp.round_
...
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00
Jake VanderPlas
f6d58761d1
jax.numpy: implement matvec & vecmat
2024-12-10 16:03:19 -08:00
jax authors
91891cb600
Merge pull request #23585 from apivovarov:float8_e4m3
...
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
Jake VanderPlas
5bebd0f6c4
fix typo in numpy/__init__.pyi
2024-11-18 11:04:33 -08:00
Jake VanderPlas
e9864c69da
Make logaddexp and logaddexp2 into ufuncs
2024-11-18 09:27:36 -08:00
Jake VanderPlas
5f94284432
Add missing functions to jax.numpy type interface
2024-11-15 12:14:55 -08:00
carlosgmartin
1f114b1cf7
Add numpy.put_along_axis.
2024-11-14 15:23:26 -05:00
Sergei Lebedev
78da9fa432
Add float8_e4m3 and float8_e3m4 types support
2024-11-08 18:58:31 +00:00
Jake VanderPlas
2b9c73d10d
Remove a number of expired deprecations.
...
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Jake VanderPlas
02daf75f97
Add new jnp.cumulative_prod function.
...
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
Jake VanderPlas
6467d03925
Make jnp.subtract a ufunc
2024-10-21 10:11:51 -07:00
Jake VanderPlas
0a85ba5f82
Better documentation for jnp.load
2024-10-19 06:20:20 -07:00
Jake VanderPlas
de3191fab3
Cleanup: fix unused imports & mark exported names
2024-10-16 17:42:41 -07:00
Jake VanderPlas
b574d2ceb1
Fix aliases in jax.numpy type interface file.
...
This includes removing some alias declarations for functions that were
previously removed.
2024-10-16 10:40:56 -07:00
Jake VanderPlas
635e29a0b9
Implement jax.numpy.spacing
...
Somehow we've missed this numpy API up until now.
2024-10-03 10:40:39 -07:00
Jake VanderPlas
c0612576de
Better documentation for jnp.choose
2024-09-27 10:35:19 -07:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
57a4b76d09
Improve documentation for jnp.digitize
2024-09-18 11:59:00 -07:00
Jake VanderPlas
e7d3785b18
Refactor & document cumulative reductions
2024-09-04 12:20:24 -07:00
Jake VanderPlas
f2ffe7f8f2
Deprecate jax.numpy.round_
...
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e
First pass at ufunc interfaces for several jax.numpy functions
2024-08-30 11:53:02 -07:00
Jake VanderPlas
c2c116dc5c
jnp.intersect1d: add support for static size argument.
2024-08-10 05:22:05 -07:00
Jake VanderPlas
14fa06298e
[array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api
2024-08-01 11:19:17 -07:00
Jake VanderPlas
c2f2b0ed28
[array API] move api metadata into jax.numpy namespace
2024-07-30 12:15:24 -07:00
Jake VanderPlas
ff8e8ad2fe
revert #22734
...
Reverts 5ce66dc1aae67a88a8ed72584bdc3f5a7f712507
PiperOrigin-RevId: 657638187
2024-07-30 10:17:34 -07:00
Jake VanderPlas
00ba7a6d25
[array API] move api metadata into jax.numpy namespace
2024-07-29 12:43:11 -07:00
vfdev-5
76d61f9d8f
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
2024-07-26 00:36:34 +02:00
Jake VanderPlas
a88a4b13fb
Add missing parameters to jnp.compress type interface
2024-07-23 07:14:46 -07:00
Dan Foreman-Mackey
991187aaa8
Fix dtype canonicalization in jnp.indices
.
...
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.
Fixes https://github.com/google/jax/issues/22501
2024-07-22 15:02:48 -04:00