Jake VanderPlas
41fa67c2dc
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
...
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Meekail Zain
6bdc83c680
Add new unstack function to numpy/array_api namespaces
2024-04-15 21:03:26 +00:00
Jake VanderPlas
e07325a672
Make complex_arr.astype(bool) follow NumPy's semantics
2024-04-09 16:15:59 -07:00
jax authors
967c38d53d
Merge pull request #20666 from curlup:main
...
PiperOrigin-RevId: 623250005
2024-04-09 12:45:48 -07:00
Pavel T
44b47035ae
better unsupported indexing handling in lax_numpy.py
2024-04-09 14:09:35 -04:00
jax authors
2512843a56
Merge pull request #20550 from Micky774:api_clip
...
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Meekail Zain
8b7aae586b
Update jnp.clip
to Array API 2023 standard
2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2
Update from_dlpack
to match array API 2023
2024-04-04 22:51:25 +00:00
Peter Hawkins
e2f47748e3
Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
...
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.
PiperOrigin-RevId: 621532796
2024-04-03 08:35:20 -07:00
Jake VanderPlas
fd7c85b349
jnp.geomspace: make complex behavior consistent with NumPy 2.0
2024-04-02 16:12:49 -07:00
Jake VanderPlas
9e01afe7af
Add jax.numpy.trapezoid
...
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Yash Katariya
d7e5ddee5e
Remove _maybe_device_put because jax.device_put accepts None
on the device parameter
...
PiperOrigin-RevId: 618223250
2024-03-22 10:39:57 -07:00
Giacomo Petrillo
fb91b51320
test digitize corner case and fix it
2024-03-18 15:55:27 -05:00
Neil Girdhar
1e580457ba
Repair various type errors
2024-03-13 15:13:56 -04:00
Meekail Zain
895ad60d60
Removed dlpack extraction leading to forced legacy path
2024-03-08 18:00:36 +00:00
Jake VanderPlas
32da56fc95
jnp.array: fix failure under numpy 2.0 copy semantics
2024-03-04 10:39:38 -08:00
Jake VanderPlas
c2d07a6623
Finalize deprecation of non-array arguments to array_equal/array_equiv
2024-02-29 05:31:37 -08:00
Jake VanderPlas
85f205bdc7
typing: fix incorrect tuple annotations
2024-02-26 10:53:19 -08:00
Till Hoffmann
2d95075ee5
Promote isclose
arguments to inexact dtype unless extended ( fixes #19935 ).
2024-02-23 09:24:09 -05:00
Dan F-M
beeaf3570e
Fixing tiny typo in jax.numpy type definitions
...
Fixing typo in function definition too
2024-02-21 20:46:51 -05:00
George Necula
18698a1f19
[shape_poly] Add support for jnp.split
2024-02-15 14:43:41 +01:00
jax authors
2c9c7dc891
Merge pull request #19739 from jakevdp:atleast-nd
...
PiperOrigin-RevId: 605776917
2024-02-09 17:46:23 -08:00
Jake VanderPlas
ebb602296e
streamline jnp.atleast_nd implementations
2024-02-09 15:10:59 -08:00
Jake VanderPlas
bbfd4f2c26
jax.numpy: implement scalar boolean indexing
2024-02-09 11:00:00 -08:00
Jake VanderPlas
1ffce4da1a
Add (private) mechanism for registering and accelerating deprecations
...
The idea is that each deprecated behavior would have an associated ID by which it could be referred to globally, so that we could call `deprecations.accelerate(module_name, ID)` in order to accelerate the deprecation period and run code with the post-deprecation behavior.
For now, these deprecation accelerations will be private APIs, but we could think about how to expose these to the user, perhaps via a config flag that finalizes all deprecations in the library.
PiperOrigin-RevId: 605064227
2024-02-07 12:28:03 -08:00
Pearu Peterson
82b2ae211c
Add CUDA Array Interface consumer support
2024-02-07 12:08:36 +02:00
Jake VanderPlas
9549c745af
jnp.full_like & co: support device parameter
2024-01-26 10:01:54 -08:00
Jake VanderPlas
43a9faa06a
Rename _wraps to implements
2024-01-24 14:14:19 -08:00
Jake VanderPlas
d55cd7c9e2
jax.numpy: support device argument for full, empty, zeros, ones
2024-01-24 12:01:09 -08:00
Jake VanderPlas
17f5658db8
jnp.diff: support scalar prepend/append
2024-01-16 08:46:44 -08:00
Jake VanderPlas
989618c5f7
[array api] add jax.numpy.concat
2024-01-12 13:12:09 -08:00
Jake VanderPlas
9890b23b0a
Add jnp.vecdot
2024-01-10 13:11:37 -08:00
Jake VanderPlas
707657e5b7
Adjust permute_dims signature to match NumPy
...
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
jax authors
856915f3c4
Merge pull request #19244 from jakevdp:permute-dims
...
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Jake VanderPlas
d673b9bf5c
[array api] add jax.numpy.permute_dims function
2024-01-08 09:30:51 -08:00
Jake VanderPlas
6278363e25
jnp.argsort/sort: explicitly deprecate the kind argument
...
This argument is a carry-over from NumPy, and has never had any effect (all jax.numpy
sorts were stable by default). Now that the new stable parameter is supported, it will
be clearer if we explicitly deprecate and eventually remove this argument.
2024-01-05 09:19:36 -08:00
Qiao Zhang
1ed6a818c7
Fix a type annotation typo.
...
PiperOrigin-RevId: 595891648
2024-01-04 22:16:11 -08:00
Jake VanderPlas
8b62516676
[array api] add stable & descending params to jnp.sort & jnp.argsort
2024-01-04 14:21:25 -08:00
jax authors
c06c2925aa
Merge pull request #19186 from jakevdp:asarray-copy
...
PiperOrigin-RevId: 595541132
2024-01-03 17:09:19 -08:00
Jake VanderPlas
97fc213eb0
[array API] support copy argument to jnp.asarray
2024-01-03 15:20:27 -08:00
Jake VanderPlas
df4e9c0d41
DOC: add warning about dlpack and buffer mutation
2024-01-03 13:31:57 -08:00
Jake VanderPlas
cab63114b4
Remove deprecated function jax.numpy.trapz
...
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html ).
PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow up to #18881 .
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
25eb913d10
don't call lax.xeinsum from jnp.einsum when str contains '{'
...
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b
remove _use_xeinsum from jnp.einsum api
...
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
jax authors
709564ab78
Move jit to the callsite.
...
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Sergei Lebedev
36f6b52e42
Upgrade most .py sources to 3.9
...
This commit was generated by running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
0a02d83015
[shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
...
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
51960048f0
jnp.nonzero: deprecate zero-dimensional inputs
2023-12-06 12:57:25 -08:00
George Necula
ec460585c8
Fix indexing with slices when the slice elements are jax.Array.
...
This fixes a bug introduced in #18679 , for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00