Jake VanderPlas
ea9eb6d2b1
[CI] avoid referencing numpy.core to fix nightly CI
2023-10-23 09:11:33 -07:00
jax authors
dde17cd5bc
Merge pull request #18180 from carlosgmartin:fill_diagonal
...
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583
Add jax.numpy.fill_diagonal.
2023-10-20 16:47:46 -04:00
Jake VanderPlas
e7bcfcff4c
Avoid numpy.core import for NumPy 2.0
2023-10-19 09:23:11 -07:00
Jake VanderPlas
1815bc7632
[typing] allow scalar shape for jnp.broadcast_to
2023-10-13 13:37:20 -07:00
Jake VanderPlas
a09fdf6e2f
Add jax.numpy.bitwise_count()
2023-10-03 13:48:16 -07:00
Jake VanderPlas
2902b32e33
[typing] allow Sequence inputs in several jax.numpy functions
2023-10-02 11:48:36 -07:00
Jake VanderPlas
adba2f0859
Add type stubs for jax.numpy.
...
This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.
PiperOrigin-RevId: 570058363
2023-10-02 07:20:20 -07:00
Patrick Kidger
baab7b181b
Fix pyright complaining about jnp.{linalg,fft} not existing.
2023-09-20 20:23:05 +01:00
Jake VanderPlas
22ff7bd19a
Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
...
These have been deprecated in JAX following similar deprecations in numpy v1.25.0
PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
70206ee6cd
Give jax.numpy.array the type Callable
.
...
This is to prevent users from using as the type of arrays in type annotations.
PiperOrigin-RevId: 560754568
2023-08-28 10:41:07 -07:00
Peter Hawkins
975dae34a4
Deprecate jax.numpy.trapz.
...
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.
Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Peter Hawkins
7c871916f7
Deprecate jax.numpy.in1d.
...
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Peter Hawkins
abff9d2898
Remove jax.numpy.alltrue from type stub.
...
This function is already deprecated.
PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
Jake VanderPlas
19a57e1a01
Deprecate jax.numpy.row_stack
2023-08-22 13:12:49 -07:00
Peter Hawkins
3082109a59
Add a type stub for jax.numpy.
...
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.
Fix a number of new type errors this reveals to mypy.
PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00
Jake VanderPlas
8bba992f9a
deprecate jax.numpy.issubsctype
2023-08-17 12:27:52 -07:00
Jake VanderPlas
ad8e719b82
Add jnp.ufunc and jnp.frompyfunc
2023-08-10 14:58:18 -07:00
Peter Hawkins
1a539bbddc
Fix type errors introduced by NINF deprecation change.
...
PiperOrigin-RevId: 555614877
2023-08-10 12:54:22 -07:00
Peter Hawkins
0e80d959c8
Mark jnp.{NINF,NZERO,PZERO} as deprecated.
...
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357 ).
PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Mateusz Sokół
1fedf04ed5
API: Remove NINF and PINF usages
2023-08-09 14:16:33 +02:00
Jake Hall
85f124c18d
Add support for float8_e4m3fnuz and float8_e5m2fnuz.
2023-08-07 11:48:53 +01:00
Jake VanderPlas
21f6736005
Remove several deprecated APIs
2023-07-11 12:42:32 -07:00
Jake VanderPlas
9962065deb
Require ml_dtypes>=0.2
2023-07-07 12:07:44 -07:00
Jake VanderPlas
2502e2a7be
Add support for dtype float8_e4m3b11fnuz
...
PiperOrigin-RevId: 538101985
2023-06-06 00:50:57 -07:00
Jake VanderPlas
3bef6214bb
Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct
2023-06-02 04:10:46 -07:00
Jake VanderPlas
333ff4abbc
Add jnp.matrix_transpose() and jax.Array.mT
...
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/ ). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Peter Hawkins
39097df02e
Add some preliminary support for int4/uint4 types to JAX.
...
PiperOrigin-RevId: 533251630
2023-05-18 14:27:33 -07:00
Jake VanderPlas
749dc1b95e
Remove deprecated function jnp.msort
2023-03-31 08:24:36 -07:00
Jake VanderPlas
6f8885a0c2
lax_numpy: move quantile-based functions to reductions.py
2023-03-23 16:39:20 -07:00
Jake VanderPlas
87aec2433b
internal: refactor array methods into separate private submodule
2023-03-23 10:57:53 -07:00
Peter Hawkins
28e4038933
Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
...
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Peter Hawkins
9bf476aaad
Redefine jnp.DeviceArray as jax.Array.
...
The concrete DeviceArray class is slated for deletion on March 15 as part of the jax.Array migration. Replace jnp.DeviceArray with its superclass (jax.Array), which is the closest equivalent in the new world.
In the future, jnp.DeviceArray will be deprecated and deleted in favor of jax.Array.
PiperOrigin-RevId: 515432432
2023-03-09 13:56:13 -08:00
Peter Hawkins
a4412e2715
Remove internal ndarray type name. Use Array throughout.
...
jax.numpy.ndarray remains an exported alias for jax.Array.
PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Jake VanderPlas
4fbaee5920
Implement jax.numpy.argpartition
2023-02-08 14:41:39 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
b730ed4645
Remove placeholder functions for unimplemented NumPy functions.
...
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
217ca5db4b
Add implementation of jnp.partition
2023-01-30 13:50:25 -08:00
Qiao Zhang
3cb5c937a3
Expose fp8 from jnp. Add the missing import.
...
PiperOrigin-RevId: 503500995
2023-01-20 12:43:59 -08:00
Jake VanderPlas
26f2f97805
Document why 'import name as name' is used
2022-12-14 15:07:04 -08:00
Jake VanderPlas
8bde3a0a70
Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place
2022-10-28 14:13:36 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
9769a0accf
DOC: ensure that _wraps() generates correct links to wrapped functions
2022-07-21 11:12:35 -07:00
Jake VanderPlas
2f4c485a54
Add dlpack support to device_array and jax.numpy
2022-07-15 17:31:11 -07:00
Jake VanderPlas
5782210174
CI: fix flake8 ignore declarations
2022-04-21 13:44:12 -07:00
Jiajie Li
128e51c638
Add polydiv to jax.numpy
...
Fix code style, fix tests
Add warning when use polydiv with trim_leading_zeros
Update warning for polydiv
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
Enable type check in _CompileAndCheck
Fix cutoff
Fix cut-off in polydiv
Add trim_zeros_tol, remove redundant code in polydiv
Remove unused import
Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
Jake VanderPlas
fbfc3d8edf
Better error messages for jnp.fromiter and jnp.fromfile
2022-03-29 14:30:32 -07:00
Jake VanderPlas
093b7032a8
Implement jnp.from* array creation functions
2022-03-29 10:52:47 -07:00
Jake VanderPlas
f4d240c036
Remove lax_numpy from jax.numpy namespace
...
This is a private module that was inadvertently exported in the past.
2022-03-25 15:02:45 -07:00