761 Commits

Author SHA1 Message Date
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Yash Katariya
d7e5ddee5e Remove _maybe_device_put because jax.device_put accepts None on the device parameter
PiperOrigin-RevId: 618223250
2024-03-22 10:39:57 -07:00
Giacomo Petrillo
fb91b51320 test digitize corner case and fix it 2024-03-18 15:55:27 -05:00
Neil Girdhar
1e580457ba Repair various type errors 2024-03-13 15:13:56 -04:00
Meekail Zain
895ad60d60 Removed dlpack extraction leading to forced legacy path 2024-03-08 18:00:36 +00:00
Jake VanderPlas
32da56fc95 jnp.array: fix failure under numpy 2.0 copy semantics 2024-03-04 10:39:38 -08:00
Jake VanderPlas
c2d07a6623 Finalize deprecation of non-array arguments to array_equal/array_equiv 2024-02-29 05:31:37 -08:00
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
Till Hoffmann
2d95075ee5 Promote isclose arguments to inexact dtype unless extended (fixes #19935). 2024-02-23 09:24:09 -05:00
Dan F-M
beeaf3570e Fixing tiny typo in jax.numpy type definitions
Fixing typo in function definition too
2024-02-21 20:46:51 -05:00
George Necula
18698a1f19 [shape_poly] Add support for jnp.split 2024-02-15 14:43:41 +01:00
jax authors
2c9c7dc891 Merge pull request #19739 from jakevdp:atleast-nd
PiperOrigin-RevId: 605776917
2024-02-09 17:46:23 -08:00
Jake VanderPlas
ebb602296e streamline jnp.atleast_nd implementations 2024-02-09 15:10:59 -08:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
1ffce4da1a Add (private) mechanism for registering and accelerating deprecations
The idea is that each deprecated behavior would have an associated ID by which it could be referred to globally, so that we could call `deprecations.accelerate(module_name, ID)` in order to accelerate the deprecation period and run code with the post-deprecation behavior.

For now, these deprecation accelerations will be private APIs, but we could think about how to expose these to the user, perhaps via a config flag that finalizes all deprecations in the library.

PiperOrigin-RevId: 605064227
2024-02-07 12:28:03 -08:00
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
Jake VanderPlas
9549c745af jnp.full_like & co: support device parameter 2024-01-26 10:01:54 -08:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
Jake VanderPlas
d55cd7c9e2 jax.numpy: support device argument for full, empty, zeros, ones 2024-01-24 12:01:09 -08:00
Jake VanderPlas
17f5658db8 jnp.diff: support scalar prepend/append 2024-01-16 08:46:44 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
Jake VanderPlas
707657e5b7 Adjust permute_dims signature to match NumPy
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
Jake VanderPlas
6278363e25 jnp.argsort/sort: explicitly deprecate the kind argument
This argument is a carry-over from NumPy, and has never had any effect (all jax.numpy
sorts were stable by default). Now that the new stable parameter is supported, it will
be clearer if we explicitly deprecate and eventually remove this argument.
2024-01-05 09:19:36 -08:00
Qiao Zhang
1ed6a818c7 Fix a type annotation typo.
PiperOrigin-RevId: 595891648
2024-01-04 22:16:11 -08:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
jax authors
c06c2925aa Merge pull request #19186 from jakevdp:asarray-copy
PiperOrigin-RevId: 595541132
2024-01-03 17:09:19 -08:00
Jake VanderPlas
97fc213eb0 [array API] support copy argument to jnp.asarray 2024-01-03 15:20:27 -08:00
Jake VanderPlas
df4e9c0d41 DOC: add warning about dlpack and buffer mutation 2024-01-03 13:31:57 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
25eb913d10 don't call lax.xeinsum from jnp.einsum when str contains '{'
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b remove _use_xeinsum from jnp.einsum api
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
jax authors
709564ab78 Move jit to the callsite.
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
51960048f0 jnp.nonzero: deprecate zero-dimensional inputs 2023-12-06 12:57:25 -08:00
George Necula
ec460585c8 Fix indexing with slices when the slice elements are jax.Array.
This fixes a bug introduced in #18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00
George Necula
d2f62612d7 Fix bug in indexing with slices that overflow, and add tests.
This bug was introduced in #18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
jax authors
e61f7a8149 Merge pull request #18757 from jakevdp:astype
PiperOrigin-RevId: 587167285
2023-12-01 17:09:02 -08:00
George Necula
2d1ce133bc [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a🅱️c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.

To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
2023-12-01 08:40:07 +02:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
jax authors
0fce77a70e Merge pull request #18708 from jakevdp:array-equal-dep
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Jake VanderPlas
a8723ecb9c Fix grad of jnp.i0 at zero 2023-11-28 12:34:56 -08:00
Jake VanderPlas
2acdb120a0 DOC: document preferred_element_type argument to dot functions 2023-11-22 09:49:34 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00