8 Commits

Author SHA1 Message Date
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00
Jake VanderPlas
2902b32e33 [typing] allow Sequence inputs in several jax.numpy functions 2023-10-02 11:48:36 -07:00
Jake VanderPlas
adba2f0859 Add type stubs for jax.numpy.
This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.

PiperOrigin-RevId: 570058363
2023-10-02 07:20:20 -07:00
Patrick Kidger
baab7b181b
Fix pyright complaining about jnp.{linalg,fft} not existing. 2023-09-20 20:23:05 +01:00
Peter Hawkins
70206ee6cd Give jax.numpy.array the type Callable.
This is to prevent users from using as the type of arrays in type annotations.

PiperOrigin-RevId: 560754568
2023-08-28 10:41:07 -07:00
Peter Hawkins
abff9d2898 Remove jax.numpy.alltrue from type stub.
This function is already deprecated.

PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
3082109a59 Add a type stub for jax.numpy.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.

Fix a number of new type errors this reveals to mypy.

PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00