765 Commits

Author SHA1 Message Date
Jake VanderPlas
ea9eb6d2b1 [CI] avoid referencing numpy.core to fix nightly CI 2023-10-23 09:11:33 -07:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Jake VanderPlas
e7bcfcff4c Avoid numpy.core import for NumPy 2.0 2023-10-19 09:23:11 -07:00
Jake VanderPlas
1815bc7632 [typing] allow scalar shape for jnp.broadcast_to 2023-10-13 13:37:20 -07:00
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00
Jake VanderPlas
2902b32e33 [typing] allow Sequence inputs in several jax.numpy functions 2023-10-02 11:48:36 -07:00
Jake VanderPlas
adba2f0859 Add type stubs for jax.numpy.
This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.

PiperOrigin-RevId: 570058363
2023-10-02 07:20:20 -07:00
Patrick Kidger
baab7b181b
Fix pyright complaining about jnp.{linalg,fft} not existing. 2023-09-20 20:23:05 +01:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
70206ee6cd Give jax.numpy.array the type Callable.
This is to prevent users from using as the type of arrays in type annotations.

PiperOrigin-RevId: 560754568
2023-08-28 10:41:07 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Peter Hawkins
abff9d2898 Remove jax.numpy.alltrue from type stub.
This function is already deprecated.

PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
3082109a59 Add a type stub for jax.numpy.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.

Fix a number of new type errors this reveals to mypy.

PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00
Jake VanderPlas
8bba992f9a deprecate jax.numpy.issubsctype 2023-08-17 12:27:52 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Peter Hawkins
1a539bbddc Fix type errors introduced by NINF deprecation change.
PiperOrigin-RevId: 555614877
2023-08-10 12:54:22 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Mateusz Sokół
1fedf04ed5 API: Remove NINF and PINF usages 2023-08-09 14:16:33 +02:00
Jake Hall
85f124c18d Add support for float8_e4m3fnuz and float8_e5m2fnuz. 2023-08-07 11:48:53 +01:00
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
2502e2a7be Add support for dtype float8_e4m3b11fnuz
PiperOrigin-RevId: 538101985
2023-06-06 00:50:57 -07:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Peter Hawkins
39097df02e Add some preliminary support for int4/uint4 types to JAX.
PiperOrigin-RevId: 533251630
2023-05-18 14:27:33 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
Jake VanderPlas
6f8885a0c2 lax_numpy: move quantile-based functions to reductions.py 2023-03-23 16:39:20 -07:00
Jake VanderPlas
87aec2433b internal: refactor array methods into separate private submodule 2023-03-23 10:57:53 -07:00
Peter Hawkins
28e4038933 Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Peter Hawkins
9bf476aaad Redefine jnp.DeviceArray as jax.Array.
The concrete DeviceArray class is slated for deletion on March 15 as part of the jax.Array migration. Replace jnp.DeviceArray with its superclass (jax.Array), which is the closest equivalent in the new world.

In the future, jnp.DeviceArray will be deprecated and deleted in favor of jax.Array.

PiperOrigin-RevId: 515432432
2023-03-09 13:56:13 -08:00
Peter Hawkins
a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Jake VanderPlas
4fbaee5920 Implement jax.numpy.argpartition 2023-02-08 14:41:39 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
217ca5db4b Add implementation of jnp.partition 2023-01-30 13:50:25 -08:00
Qiao Zhang
3cb5c937a3 Expose fp8 from jnp. Add the missing import.
PiperOrigin-RevId: 503500995
2023-01-20 12:43:59 -08:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Jake VanderPlas
8bde3a0a70 Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place 2022-10-28 14:13:36 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
9769a0accf DOC: ensure that _wraps() generates correct links to wrapped functions 2022-07-21 11:12:35 -07:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jiajie Li
128e51c638 Add polydiv to jax.numpy
Fix code style, fix tests

Add warning when use polydiv with trim_leading_zeros

Update warning for polydiv

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Enable type check in _CompileAndCheck

Fix cutoff

Fix cut-off in polydiv

Add trim_zeros_tol, remove redundant code in polydiv

Remove unused import

Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
Jake VanderPlas
fbfc3d8edf Better error messages for jnp.fromiter and jnp.fromfile 2022-03-29 14:30:32 -07:00
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
Jake VanderPlas
f4d240c036 Remove lax_numpy from jax.numpy namespace
This is a private module that was inadvertently exported in the past.
2022-03-25 15:02:45 -07:00