45 Commits

Author SHA1 Message Date
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbc1ee2ba3 Remove some dead code and unused imports 2023-04-12 12:15:15 -07:00
Jake VanderPlas
e061b91ffc Fix uint32 scatter assignment 2023-04-10 14:24:26 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
jax authors
e1d118c38d Merge pull request #13476 from jakevdp:x64-lax-numpy
PiperOrigin-RevId: 492367125
2022-12-01 20:29:46 -08:00
Jake VanderPlas
3cf2924ed6 [x64] minor fixes for lax_numpy_test type safety 2022-12-01 13:56:42 -08:00
Jake VanderPlas
b037feb105 [x64] more type safety for lax_numpy-related tests 2022-12-01 11:18:02 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
Jake VanderPlas
8ac9ea312a [typing] annotate jax.scipy.special 2022-10-13 12:16:12 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
ed06838006 [typing] clear up logic in scatter_update
Static type checkers do not parse deeply enough to know that by line 182
bucket_size cannot by None; branching on an explicit None check is easier
to follow (even for human readers)
2022-09-14 15:32:52 -07:00
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Jake VanderPlas
0a72adbd5e lax_numpy: factor out indexing tricks 2022-03-17 11:05:45 -07:00
Jake VanderPlas
4f6f4e5554 segment_max: fix identity for boolean dtype 2022-03-15 09:20:20 -07:00
jax authors
cf1161ff8b Merge pull request #9826 from froystig:lax-cleanup2
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Peter Hawkins
adb37d9d52 Speed up compilation time of segment_sum() operators with bucketing enabled. 2022-03-10 09:26:46 -05:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Peter Hawkins
f51a05a889 Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Jake VanderPlas
47e88ded05 [x64] ensure scatter functionality preserves weak_type 2021-11-30 15:43:06 -08:00
Peter Hawkins
4679f455f9 Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
9fee130d6b [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 404395250
2021-10-19 16:39:09 -07:00
Peter Hawkins
104a46594b Add DeprecationWarnings to jax.ops.index_... operators.
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
Peter Hawkins
6a284ce5ad Fix incorrect EllipsisType reference for Python 3.10 2021-10-05 16:16:59 -04:00
Jake VanderPlas
c35b2f2485 DOC: move index update API docs to jnp.ndarray.at
- Add docstring to abstract  property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
2021-10-01 14:06:08 -07:00
Peter Hawkins
867068821e Drop out-of-bounds indexes in gather. 2021-09-23 10:35:03 -04:00
Peter Hawkins
46288a299e Fix wrong dtype output from indexing with an empty slice.
Fix test failures in lax_numpy_indexing_test for complex64 pow() scatters.

The existing tests actually catch both of these bugs, but only when run with a sufficiently high number of test cases.
2021-09-22 15:24:56 -04:00
Jake VanderPlas
08e1c831ba Validate shapes for boolean indices 2021-08-03 15:33:44 -07:00
Peter Hawkins
2168483a62 Add x.at[idx].get().
This allows the sorted/unique keyword arguments to be passed to indexed gather operations.
2021-07-07 08:51:09 -04:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Peter Hawkins
44c98ad4e8 Improve JVP rule for scatters with non-overlapping indices.
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
Jake VanderPlas
cbb7052379 Implement segment_prod, segment_max, segment_min 2021-04-09 12:06:51 -07:00
Jake VanderPlas
48ac77d272 jax.ops.segment_sum: improve input validation 2021-04-08 08:39:58 -07:00
Jake VanderPlas
8e789c7380 Run doctest on all source files except jax2tf 2021-04-05 10:39:59 -07:00
Neil Girdhar
af0988ee7f Annotate scatter and random 2021-03-24 20:06:17 -04:00
Jake VanderPlas
cfe934c053 Fix some doc build warnings 2021-01-25 14:08:57 -08:00
Peter Hawkins
9ffe009964 Change segment_sum to drop out-of-bounds indices rather than wrap them.
This is a breaking change.
2021-01-13 10:54:27 -05:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
AdrienCorenflos
3bc34e4f5a
Update scatter.py
Looks like the documentation was not up to date
2020-11-05 18:13:44 +02:00
Peter Hawkins
ef57858deb Move jax.ops implementation into jax._src.ops. 2020-10-17 11:45:28 -04:00