734 Commits

Author SHA1 Message Date
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
jax authors
c06c2925aa Merge pull request #19186 from jakevdp:asarray-copy
PiperOrigin-RevId: 595541132
2024-01-03 17:09:19 -08:00
Jake VanderPlas
97fc213eb0 [array API] support copy argument to jnp.asarray 2024-01-03 15:20:27 -08:00
Jake VanderPlas
df4e9c0d41 DOC: add warning about dlpack and buffer mutation 2024-01-03 13:31:57 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
25eb913d10 don't call lax.xeinsum from jnp.einsum when str contains '{'
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b remove _use_xeinsum from jnp.einsum api
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
jax authors
709564ab78 Move jit to the callsite.
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
51960048f0 jnp.nonzero: deprecate zero-dimensional inputs 2023-12-06 12:57:25 -08:00
George Necula
ec460585c8 Fix indexing with slices when the slice elements are jax.Array.
This fixes a bug introduced in #18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00
George Necula
d2f62612d7 Fix bug in indexing with slices that overflow, and add tests.
This bug was introduced in #18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
jax authors
e61f7a8149 Merge pull request #18757 from jakevdp:astype
PiperOrigin-RevId: 587167285
2023-12-01 17:09:02 -08:00
George Necula
2d1ce133bc [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a🅱️c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.

To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
2023-12-01 08:40:07 +02:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
jax authors
0fce77a70e Merge pull request #18708 from jakevdp:array-equal-dep
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Jake VanderPlas
a8723ecb9c Fix grad of jnp.i0 at zero 2023-11-28 12:34:56 -08:00
Jake VanderPlas
2acdb120a0 DOC: document preferred_element_type argument to dot functions 2023-11-22 09:49:34 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
Matthew Johnson
6b6b44d409 add error hint about common jnp.ones / jnp.zeros mistake 2023-11-15 19:52:16 -08:00
jax authors
8f8b2550f1 Merge pull request #18554 from mattjj:rot90-error-message
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
Matthew Johnson
2288f64563 rot90 validate argument has ndim at least 2 2023-11-15 18:24:42 -08:00
Matthew Johnson
4654eedb10 improve jnp.reshape's error message 2023-11-15 16:21:13 -08:00
Jake VanderPlas
416b734567 Fix boolean indexing check with newaxis 2023-11-15 09:03:15 -08:00
Neil Girdhar
1452e219df Annotate jax.numpy.atleast_xd functions
These tighter annotations should reduce the need for casting and
ignoring types without causing any type errors.
2023-11-09 19:21:35 -05:00
Jake VanderPlas
a30d51ba2e jnp.histogram: avoid flattening input 2023-11-08 08:55:09 -08:00
Jake VanderPlas
4f863e9148 jnp.cross: account for numpy 2.0 deprecation 2023-11-03 14:15:23 -07:00
Jake VanderPlas
fbacebc11e jnp.einsum: mention default value for optimize param 2023-10-30 09:22:37 -07:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
jax authors
73a973eaa8 Merge pull request #18000 from alhridoy:arange-precision-warning
PiperOrigin-RevId: 575328461
2023-10-20 15:10:12 -07:00
alhridoy
63f7cfe04c Add precision warning and workaround to jnp.arange documentation 2023-10-20 15:34:12 -06:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Jake VanderPlas
1815bc7632 [typing] allow scalar shape for jnp.broadcast_to 2023-10-13 13:37:20 -07:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5d9c39f4b0 MAINT Use a generator expression with all() and any()
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jake VanderPlas
911f745775 Make jax._src.typing.DTypeLike more strictly defined
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.

PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00
Jake VanderPlas
2902b32e33 [typing] allow Sequence inputs in several jax.numpy functions 2023-10-02 11:48:36 -07:00
Jake VanderPlas
e8ebe462d2 Deprecate non-array inputs to several jax.numpy functions 2023-09-22 14:21:23 -07:00
Yash Katariya
426970591b If an input to jnp.asarray is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
Do a similar thing for jax.Array too if dtypes match.

Fixes https://github.com/google/jax/issues/17702

PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
4edb74ba7b Fix some numpy 2.0 incompatibilities 2023-09-21 10:24:52 -07:00
Jake VanderPlas
505f03b40f Avoid references to symbols removed in numpy 2.0 2023-09-19 11:50:21 -07:00
Jake VanderPlas
3386e54fe0 jnp.inner: add preferred_element_type argument 2023-09-14 16:40:19 -07:00
Brennan Saeta
1cef3e85f4 Fix error message for zeros_like which was referencing ones_like.
PiperOrigin-RevId: 565413589
2023-09-14 10:43:57 -07:00
Brian Patton
ed955ea7bf Fully unroll the scan in jnp.searchsorted, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch.
Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.)

PiperOrigin-RevId: 565371859
2023-09-14 08:10:49 -07:00
Jake VanderPlas
9289f3250b Add missing preferred_element_type tests
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00