797 Commits

Author SHA1 Message Date
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2 Update from_dlpack to match array API 2023 2024-04-04 22:51:25 +00:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Jake VanderPlas
ac19d0f3b2 Fix bool annotations in jax.numpy APIs
PiperOrigin-RevId: 609704549
2024-02-23 06:03:04 -08:00
Dan F-M
beeaf3570e Fixing tiny typo in jax.numpy type definitions
Fixing typo in function definition too
2024-02-21 20:46:51 -05:00
Jake VanderPlas
9549c745af jnp.full_like & co: support device parameter 2024-01-26 10:01:54 -08:00
Jake VanderPlas
d55cd7c9e2 jax.numpy: support device argument for full, empty, zeros, ones 2024-01-24 12:01:09 -08:00
Jake VanderPlas
71e96faf1a [array API] add jnp.bool 2024-01-17 14:34:27 -08:00
Jake VanderPlas
fbf7492a2c Add jnp.isdtype function, following np.isdtype in NumPy 2.0 2024-01-17 12:14:55 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Jake VanderPlas
b08a010949 [array API] add jnp.linalg.diagonal 2024-01-11 12:52:15 -08:00
Jake VanderPlas
1a39d8fdb2 [array API] implement jnp.pow; alias for jnp.power 2024-01-10 14:59:46 -08:00
Jake VanderPlas
4e55086dfb array api: add jnp.bitwise_* aliases 2024-01-10 14:22:20 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
jax authors
ab6dd273c9 Fix a type annotation.
PiperOrigin-RevId: 595984433
2024-01-05 07:01:05 -08:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
Jake VanderPlas
97fc213eb0 [array API] support copy argument to jnp.asarray 2024-01-03 15:20:27 -08:00
Jake VanderPlas
5e957c6063 array api: add unique_* interfaces 2023-12-21 15:49:54 -08:00
Jake VanderPlas
e3e26f2dde jnp.unique: add support for the equal_nan keyword 2023-12-21 12:37:09 -08:00
Jake VanderPlas
e98bb7c3ab jax.numpy: add trig aliases acos(h), asin(h), atan(h), atan2 2023-12-19 14:15:29 -08:00
Jake VanderPlas
832ac874bd jnp.linalg: add matmul, tensordot, & svdvals 2023-12-19 11:36:09 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Jake VanderPlas
0c7b959dac jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose
These have been added upstream to numpy.linalg in NumPy 2.0, as part of the Array API standard.
2023-12-15 14:17:36 -08:00
Jake VanderPlas
fe2ad89209 array api: add jnp.linalg.cross & jnp.linalg.outer 2023-12-12 11:22:31 -08:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Neil Girdhar
1452e219df Annotate jax.numpy.atleast_xd functions
These tighter annotations should reduce the need for casting and
ignoring types without causing any type errors.
2023-11-09 19:21:35 -05:00
Jake VanderPlas
ea9eb6d2b1 [CI] avoid referencing numpy.core to fix nightly CI 2023-10-23 09:11:33 -07:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Jake VanderPlas
e7bcfcff4c Avoid numpy.core import for NumPy 2.0 2023-10-19 09:23:11 -07:00
Jake VanderPlas
1815bc7632 [typing] allow scalar shape for jnp.broadcast_to 2023-10-13 13:37:20 -07:00
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00
Jake VanderPlas
2902b32e33 [typing] allow Sequence inputs in several jax.numpy functions 2023-10-02 11:48:36 -07:00
Jake VanderPlas
adba2f0859 Add type stubs for jax.numpy.
This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.

PiperOrigin-RevId: 570058363
2023-10-02 07:20:20 -07:00
Patrick Kidger
baab7b181b
Fix pyright complaining about jnp.{linalg,fft} not existing. 2023-09-20 20:23:05 +01:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
70206ee6cd Give jax.numpy.array the type Callable.
This is to prevent users from using as the type of arrays in type annotations.

PiperOrigin-RevId: 560754568
2023-08-28 10:41:07 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Peter Hawkins
abff9d2898 Remove jax.numpy.alltrue from type stub.
This function is already deprecated.

PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
3082109a59 Add a type stub for jax.numpy.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.

Fix a number of new type errors this reveals to mypy.

PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00
Jake VanderPlas
8bba992f9a deprecate jax.numpy.issubsctype 2023-08-17 12:27:52 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00