57 Commits

Author SHA1 Message Date
Jake VanderPlas
5b28170b94 Support scalar boolean indices in arr.at[idx].set(vals) 2024-05-20 05:33:36 -07:00
Jake VanderPlas
74f1d8897c DOC: add manual documentation to jax.scipy.special functions.
This lets us give more implementation-specific information, and
lets us avoid a needless dependency on scipy.
2024-04-29 10:58:07 -07:00
carlosgmartin
e98612e2ab Add where argument to logsumexp. 2024-04-08 12:57:06 -04:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00
George Necula
30bc5a2a5f [shape_poly] Update the jax.ops.segment{max|...} to with with shape polymorphism
The fix is very small, just had to check how we check for cases when tracers
are passed as num_segments. We add tests.
2023-12-19 12:02:39 +02:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Jake VanderPlas
8412781127 Internal: add dtypes.safe_to_cast utility & use to generate indexing warning 2023-09-07 12:18:14 -07:00
Jake VanderPlas
27324ff18c special.logsumexp: fix incorrect annotation 2023-08-21 09:10:19 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbc1ee2ba3 Remove some dead code and unused imports 2023-04-12 12:15:15 -07:00
Jake VanderPlas
e061b91ffc Fix uint32 scatter assignment 2023-04-10 14:24:26 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
jax authors
e1d118c38d Merge pull request #13476 from jakevdp:x64-lax-numpy
PiperOrigin-RevId: 492367125
2022-12-01 20:29:46 -08:00
Jake VanderPlas
3cf2924ed6 [x64] minor fixes for lax_numpy_test type safety 2022-12-01 13:56:42 -08:00
Jake VanderPlas
b037feb105 [x64] more type safety for lax_numpy-related tests 2022-12-01 11:18:02 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
Jake VanderPlas
8ac9ea312a [typing] annotate jax.scipy.special 2022-10-13 12:16:12 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
ed06838006 [typing] clear up logic in scatter_update
Static type checkers do not parse deeply enough to know that by line 182
bucket_size cannot by None; branching on an explicit None check is easier
to follow (even for human readers)
2022-09-14 15:32:52 -07:00
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Jake VanderPlas
0a72adbd5e lax_numpy: factor out indexing tricks 2022-03-17 11:05:45 -07:00
Jake VanderPlas
4f6f4e5554 segment_max: fix identity for boolean dtype 2022-03-15 09:20:20 -07:00
jax authors
cf1161ff8b Merge pull request #9826 from froystig:lax-cleanup2
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Peter Hawkins
adb37d9d52 Speed up compilation time of segment_sum() operators with bucketing enabled. 2022-03-10 09:26:46 -05:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Peter Hawkins
f51a05a889 Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Jake VanderPlas
47e88ded05 [x64] ensure scatter functionality preserves weak_type 2021-11-30 15:43:06 -08:00
Peter Hawkins
4679f455f9 Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
9fee130d6b [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 404395250
2021-10-19 16:39:09 -07:00
Peter Hawkins
104a46594b Add DeprecationWarnings to jax.ops.index_... operators.
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
Peter Hawkins
6a284ce5ad Fix incorrect EllipsisType reference for Python 3.10 2021-10-05 16:16:59 -04:00
Jake VanderPlas
c35b2f2485 DOC: move index update API docs to jnp.ndarray.at
- Add docstring to abstract  property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
2021-10-01 14:06:08 -07:00
Peter Hawkins
867068821e Drop out-of-bounds indexes in gather. 2021-09-23 10:35:03 -04:00
Peter Hawkins
46288a299e Fix wrong dtype output from indexing with an empty slice.
Fix test failures in lax_numpy_indexing_test for complex64 pow() scatters.

The existing tests actually catch both of these bugs, but only when run with a sufficiently high number of test cases.
2021-09-22 15:24:56 -04:00
Jake VanderPlas
08e1c831ba Validate shapes for boolean indices 2021-08-03 15:33:44 -07:00
Peter Hawkins
2168483a62 Add x.at[idx].get().
This allows the sorted/unique keyword arguments to be passed to indexed gather operations.
2021-07-07 08:51:09 -04:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Peter Hawkins
44c98ad4e8 Improve JVP rule for scatters with non-overlapping indices.
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
Jake VanderPlas
cbb7052379 Implement segment_prod, segment_max, segment_min 2021-04-09 12:06:51 -07:00
Jake VanderPlas
48ac77d272 jax.ops.segment_sum: improve input validation 2021-04-08 08:39:58 -07:00