748 Commits

Author SHA1 Message Date
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
1ffce4da1a Add (private) mechanism for registering and accelerating deprecations
The idea is that each deprecated behavior would have an associated ID by which it could be referred to globally, so that we could call `deprecations.accelerate(module_name, ID)` in order to accelerate the deprecation period and run code with the post-deprecation behavior.

For now, these deprecation accelerations will be private APIs, but we could think about how to expose these to the user, perhaps via a config flag that finalizes all deprecations in the library.

PiperOrigin-RevId: 605064227
2024-02-07 12:28:03 -08:00
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
Jake VanderPlas
9549c745af jnp.full_like & co: support device parameter 2024-01-26 10:01:54 -08:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
Jake VanderPlas
d55cd7c9e2 jax.numpy: support device argument for full, empty, zeros, ones 2024-01-24 12:01:09 -08:00
Jake VanderPlas
17f5658db8 jnp.diff: support scalar prepend/append 2024-01-16 08:46:44 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
Jake VanderPlas
707657e5b7 Adjust permute_dims signature to match NumPy
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
Jake VanderPlas
6278363e25 jnp.argsort/sort: explicitly deprecate the kind argument
This argument is a carry-over from NumPy, and has never had any effect (all jax.numpy
sorts were stable by default). Now that the new stable parameter is supported, it will
be clearer if we explicitly deprecate and eventually remove this argument.
2024-01-05 09:19:36 -08:00
Qiao Zhang
1ed6a818c7 Fix a type annotation typo.
PiperOrigin-RevId: 595891648
2024-01-04 22:16:11 -08:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
jax authors
c06c2925aa Merge pull request #19186 from jakevdp:asarray-copy
PiperOrigin-RevId: 595541132
2024-01-03 17:09:19 -08:00
Jake VanderPlas
97fc213eb0 [array API] support copy argument to jnp.asarray 2024-01-03 15:20:27 -08:00
Jake VanderPlas
df4e9c0d41 DOC: add warning about dlpack and buffer mutation 2024-01-03 13:31:57 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
25eb913d10 don't call lax.xeinsum from jnp.einsum when str contains '{'
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b remove _use_xeinsum from jnp.einsum api
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
jax authors
709564ab78 Move jit to the callsite.
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
51960048f0 jnp.nonzero: deprecate zero-dimensional inputs 2023-12-06 12:57:25 -08:00
George Necula
ec460585c8 Fix indexing with slices when the slice elements are jax.Array.
This fixes a bug introduced in #18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00
George Necula
d2f62612d7 Fix bug in indexing with slices that overflow, and add tests.
This bug was introduced in #18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
jax authors
e61f7a8149 Merge pull request #18757 from jakevdp:astype
PiperOrigin-RevId: 587167285
2023-12-01 17:09:02 -08:00
George Necula
2d1ce133bc [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a🅱️c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.

To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
2023-12-01 08:40:07 +02:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
jax authors
0fce77a70e Merge pull request #18708 from jakevdp:array-equal-dep
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Jake VanderPlas
a8723ecb9c Fix grad of jnp.i0 at zero 2023-11-28 12:34:56 -08:00
Jake VanderPlas
2acdb120a0 DOC: document preferred_element_type argument to dot functions 2023-11-22 09:49:34 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
Matthew Johnson
6b6b44d409 add error hint about common jnp.ones / jnp.zeros mistake 2023-11-15 19:52:16 -08:00
jax authors
8f8b2550f1 Merge pull request #18554 from mattjj:rot90-error-message
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
Matthew Johnson
2288f64563 rot90 validate argument has ndim at least 2 2023-11-15 18:24:42 -08:00
Matthew Johnson
4654eedb10 improve jnp.reshape's error message 2023-11-15 16:21:13 -08:00
Jake VanderPlas
416b734567 Fix boolean indexing check with newaxis 2023-11-15 09:03:15 -08:00
Neil Girdhar
1452e219df Annotate jax.numpy.atleast_xd functions
These tighter annotations should reduce the need for casting and
ignoring types without causing any type errors.
2023-11-09 19:21:35 -05:00
Jake VanderPlas
a30d51ba2e jnp.histogram: avoid flattening input 2023-11-08 08:55:09 -08:00
Jake VanderPlas
4f863e9148 jnp.cross: account for numpy 2.0 deprecation 2023-11-03 14:15:23 -07:00
Jake VanderPlas
fbacebc11e jnp.einsum: mention default value for optimize param 2023-10-30 09:22:37 -07:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
jax authors
73a973eaa8 Merge pull request #18000 from alhridoy:arange-precision-warning
PiperOrigin-RevId: 575328461
2023-10-20 15:10:12 -07:00
alhridoy
63f7cfe04c Add precision warning and workaround to jnp.arange documentation 2023-10-20 15:34:12 -06:00