Jake VanderPlas
8b62516676
[array api] add stable & descending params to jnp.sort & jnp.argsort
2024-01-04 14:21:25 -08:00
jax authors
c06c2925aa
Merge pull request #19186 from jakevdp:asarray-copy
...
PiperOrigin-RevId: 595541132
2024-01-03 17:09:19 -08:00
Jake VanderPlas
97fc213eb0
[array API] support copy argument to jnp.asarray
2024-01-03 15:20:27 -08:00
Jake VanderPlas
df4e9c0d41
DOC: add warning about dlpack and buffer mutation
2024-01-03 13:31:57 -08:00
Jake VanderPlas
cab63114b4
Remove deprecated function jax.numpy.trapz
...
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html ).
PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow up to #18881 .
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
25eb913d10
don't call lax.xeinsum from jnp.einsum when str contains '{'
...
can still call lax.xeinsum directly
2023-12-09 11:11:31 -08:00
Matthew Johnson
9a1a09c28b
remove _use_xeinsum from jnp.einsum api
...
can still call jnp.einsum with a '{' in the spec string to trigger xeinsum, or
just call lax.xeinsum directly
2023-12-09 10:53:22 -08:00
jax authors
709564ab78
Move jit to the callsite.
...
PiperOrigin-RevId: 589328135
2023-12-08 22:19:56 -08:00
Sergei Lebedev
36f6b52e42
Upgrade most .py sources to 3.9
...
This commit was generated by running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
0a02d83015
[shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
...
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
51960048f0
jnp.nonzero: deprecate zero-dimensional inputs
2023-12-06 12:57:25 -08:00
George Necula
ec460585c8
Fix indexing with slices when the slice elements are jax.Array.
...
This fixes a bug introduced in #18679 , for the case when some
elements of the slice are `jax.Array`. We add a new test also.
2023-12-05 08:02:50 +01:00
George Necula
d2f62612d7
Fix bug in indexing with slices that overflow, and add tests.
...
This bug was introduced in #18679 , and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
2023-12-02 16:47:06 +02:00
jax authors
e61f7a8149
Merge pull request #18757 from jakevdp:astype
...
PiperOrigin-RevId: 587167285
2023-12-01 17:09:02 -08:00
George Necula
2d1ce133bc
[shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism
...
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a🅱️ c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.
Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.
To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
2023-12-01 08:40:07 +02:00
Jake VanderPlas
d77cd9a0f4
Add jax.numpy.astype function
2023-11-30 15:50:22 -08:00
jax authors
0fce77a70e
Merge pull request #18708 from jakevdp:array-equal-dep
...
PiperOrigin-RevId: 586357829
2023-11-29 08:58:29 -08:00
Jake VanderPlas
13dd5e42cc
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
2023-11-28 13:55:18 -08:00
Jake VanderPlas
a8723ecb9c
Fix grad of jnp.i0 at zero
2023-11-28 12:34:56 -08:00
Jake VanderPlas
2acdb120a0
DOC: document preferred_element_type argument to dot functions
2023-11-22 09:49:34 -08:00
Matthew Johnson
67677eb10e
improve error message for e.g. jnp.zeros(5)[:, 0]
2023-11-21 15:59:21 -08:00
Peter Hawkins
84c1e825c0
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
...
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Jake VanderPlas
84aa7e5c53
Deprecate passing of None to jax.numpy.array
2023-11-16 15:10:56 -08:00
Matthew Johnson
6b6b44d409
add error hint about common jnp.ones / jnp.zeros mistake
2023-11-15 19:52:16 -08:00
jax authors
8f8b2550f1
Merge pull request #18554 from mattjj:rot90-error-message
...
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
Matthew Johnson
2288f64563
rot90 validate argument has ndim at least 2
2023-11-15 18:24:42 -08:00
Matthew Johnson
4654eedb10
improve jnp.reshape's error message
2023-11-15 16:21:13 -08:00
Jake VanderPlas
416b734567
Fix boolean indexing check with newaxis
2023-11-15 09:03:15 -08:00
Neil Girdhar
1452e219df
Annotate jax.numpy.atleast_xd functions
...
These tighter annotations should reduce the need for casting and
ignoring types without causing any type errors.
2023-11-09 19:21:35 -05:00
Jake VanderPlas
a30d51ba2e
jnp.histogram: avoid flattening input
2023-11-08 08:55:09 -08:00
Jake VanderPlas
4f863e9148
jnp.cross: account for numpy 2.0 deprecation
2023-11-03 14:15:23 -07:00
Jake VanderPlas
fbacebc11e
jnp.einsum: mention default value for optimize param
2023-10-30 09:22:37 -07:00
Sergei Lebedev
f2ce5dbd01
MAINT Do not use str()
and repr()
in f-string replacement fields
...
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
jax authors
73a973eaa8
Merge pull request #18000 from alhridoy:arange-precision-warning
...
PiperOrigin-RevId: 575328461
2023-10-20 15:10:12 -07:00
alhridoy
63f7cfe04c
Add precision warning and workaround to jnp.arange documentation
2023-10-20 15:34:12 -06:00
carlosgmartin
3cb504c583
Add jax.numpy.fill_diagonal.
2023-10-20 16:47:46 -04:00
Jake VanderPlas
1815bc7632
[typing] allow scalar shape for jnp.broadcast_to
2023-10-13 13:37:20 -07:00
Sergei Lebedev
2f70ae700a
Migrate another subset of internal modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5d9c39f4b0
MAINT Use a generator expression with all()
and any()
...
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jake VanderPlas
911f745775
Make jax._src.typing.DTypeLike more strictly defined
...
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.
PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00
Jake VanderPlas
2902b32e33
[typing] allow Sequence inputs in several jax.numpy functions
2023-10-02 11:48:36 -07:00
Jake VanderPlas
e8ebe462d2
Deprecate non-array inputs to several jax.numpy functions
2023-09-22 14:21:23 -07:00
Yash Katariya
426970591b
If an input to jnp.asarray
is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
...
Do a similar thing for jax.Array too if dtypes match.
Fixes https://github.com/google/jax/issues/17702
PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
4edb74ba7b
Fix some numpy 2.0 incompatibilities
2023-09-21 10:24:52 -07:00
Jake VanderPlas
505f03b40f
Avoid references to symbols removed in numpy 2.0
2023-09-19 11:50:21 -07:00
Jake VanderPlas
3386e54fe0
jnp.inner: add preferred_element_type argument
2023-09-14 16:40:19 -07:00
Brennan Saeta
1cef3e85f4
Fix error message for zeros_like which was referencing ones_like.
...
PiperOrigin-RevId: 565413589
2023-09-14 10:43:57 -07:00
Brian Patton
ed955ea7bf
Fully unroll the scan
in jnp.searchsorted
, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch.
...
Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.)
PiperOrigin-RevId: 565371859
2023-09-14 08:10:49 -07:00
Jake VanderPlas
9289f3250b
Add missing preferred_element_type tests
...
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00