42 Commits

Author SHA1 Message Date
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Jake VanderPlas
572c16284e [array api] update to latest test repo commit 2024-04-16 06:09:00 -07:00
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
jax authors
5f22b12576 Merge pull request #20754 from Micky774:array-api-hypot
PiperOrigin-RevId: 625035601
2024-04-15 11:56:53 -07:00
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
Meekail Zain
8b93da1830 Expose existing functions in array API namespace 2024-04-15 16:25:30 +00:00
jax authors
301c3518d8 Merge pull request #20294 from Micky774:array_namespace_info
PiperOrigin-RevId: 623877931
2024-04-11 11:09:37 -07:00
Meekail Zain
e6508a4f47 Add __array_namespace_info__ and corresponding utilities 2024-04-11 14:20:44 +00:00
jax authors
77db7a60ed Merge pull request #20637 from jakevdp:array-api-scalar
PiperOrigin-RevId: 623184036
2024-04-09 09:04:12 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Jake VanderPlas
c19c1a7148 [array api] allow Python scalar arguments to functions 2024-04-08 10:10:01 -07:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2 Update from_dlpack to match array API 2023 2024-04-04 22:51:25 +00:00
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
jax authors
e1e9de0e7b Merge pull request #19499 from nstarman:array_api-broadcast_to-type-hint
PiperOrigin-RevId: 609852827
2024-02-23 15:28:33 -08:00
nstarman
b9f28572f7 Fix type annotation for array_api.broadcast_to
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
2024-02-23 18:32:43 +00:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
6b713071e4 [array API] update to recent array API test version 2024-02-05 15:15:57 -08:00
Jake VanderPlas
c9a700921b jnp.linalg.cholesky: add upper argument 2024-01-31 14:16:12 -08:00
jax authors
df28ee7fd8 Merge pull request #19400 from jakevdp:jnp-isdtype
PiperOrigin-RevId: 599265256
2024-01-17 13:08:51 -08:00
Jake VanderPlas
fbf7492a2c Add jnp.isdtype function, following np.isdtype in NumPy 2.0 2024-01-17 12:14:55 -08:00
jax authors
3e8067060e Merge pull request #19347 from jakevdp:linalg-retvals
PiperOrigin-RevId: 599244819
2024-01-17 11:55:32 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Jake VanderPlas
012c5bd439 [array api] return NamedTuple from np.linalg APIs 2024-01-12 13:10:50 -08:00
Jake VanderPlas
b08a010949 [array API] add jnp.linalg.diagonal 2024-01-11 12:52:15 -08:00
Jake VanderPlas
c906f44ac1 array api: simplify some wrappers 2024-01-10 15:49:15 -08:00
Jake VanderPlas
1a39d8fdb2 [array API] implement jnp.pow; alias for jnp.power 2024-01-10 14:59:46 -08:00
Jake VanderPlas
4e55086dfb array api: add jnp.bitwise_* aliases 2024-01-10 14:22:20 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
Jake VanderPlas
38257389af [array api] fix linalg.solve and enable test 2024-01-05 15:30:18 -08:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
Jake VanderPlas
5e957c6063 array api: add unique_* interfaces 2023-12-21 15:49:54 -08:00
Jake VanderPlas
832ac874bd jnp.linalg: add matmul, tensordot, & svdvals 2023-12-19 11:36:09 -08:00
Jake VanderPlas
0c7b959dac jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose
These have been added upstream to numpy.linalg in NumPy 2.0, as part of the Array API standard.
2023-12-15 14:17:36 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
fe2ad89209 array api: add jnp.linalg.cross & jnp.linalg.outer 2023-12-12 11:22:31 -08:00
jax authors
809a37c567 Merge pull request #18881 from superbobry:pyupgrade
PiperOrigin-RevId: 589191161
2023-12-08 11:20:50 -08:00
Jake VanderPlas
4b1077da09 array-api: update test suite & fix nonzero 2023-12-08 08:55:57 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Jake VanderPlas
d6154e5d89 [array-api] remove some test skips 2023-11-30 13:28:08 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00