Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbc1ee2ba3
Remove some dead code and unused imports
2023-04-12 12:15:15 -07:00
Jake VanderPlas
e061b91ffc
Fix uint32 scatter assignment
2023-04-10 14:24:26 -07:00
Jake VanderPlas
760deb310e
Remove leading underscores in jax._src.numpy.util
2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5
internal: avoid unused imports in lax_numpy
2023-03-08 10:29:04 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Jake VanderPlas
4389216d0c
Remove typing_extensions dependency
2022-12-05 15:42:26 -08:00
jax authors
e1d118c38d
Merge pull request #13476 from jakevdp:x64-lax-numpy
...
PiperOrigin-RevId: 492367125
2022-12-01 20:29:46 -08:00
Jake VanderPlas
3cf2924ed6
[x64] minor fixes for lax_numpy_test type safety
2022-12-01 13:56:42 -08:00
Jake VanderPlas
b037feb105
[x64] more type safety for lax_numpy-related tests
2022-12-01 11:18:02 -08:00
Yash Katariya
a419e1917a
Use jax.Array by default for doctests
...
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6
Add a number of missing function cross-references in the docs.
2022-11-07 12:00:26 -05:00
Jake VanderPlas
709ffd7e77
[typing] annotate jax.numpy reduction operations
2022-10-26 13:33:15 -07:00
Jake VanderPlas
8ac9ea312a
[typing] annotate jax.scipy.special
2022-10-13 12:16:12 -07:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
ed06838006
[typing] clear up logic in scatter_update
...
Static type checkers do not parse deeply enough to know that by line 182
bucket_size cannot by None; branching on an explicit None check is easier
to follow (even for human readers)
2022-09-14 15:32:52 -07:00
Jake VanderPlas
d2f80ef117
[x64] deprecate unsafe type casting in scatter-update operations
2022-06-09 15:21:49 -07:00
Jake VanderPlas
0a72adbd5e
lax_numpy: factor out indexing tricks
2022-03-17 11:05:45 -07:00
Jake VanderPlas
4f6f4e5554
segment_max: fix identity for boolean dtype
2022-03-15 09:20:20 -07:00
jax authors
cf1161ff8b
Merge pull request #9826 from froystig:lax-cleanup2
...
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Peter Hawkins
adb37d9d52
Speed up compilation time of segment_sum() operators with bucketing enabled.
2022-03-10 09:26:46 -05:00
Roy Frostig
8f93629e87
remove _convert_element_type
from public jax.lax
module
2022-03-09 18:46:38 -08:00
Peter Hawkins
f51a05a889
Remove jax.ops.index... functions.
...
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Jake VanderPlas
47e88ded05
[x64] ensure scatter functionality preserves weak_type
2021-11-30 15:43:06 -08:00
Peter Hawkins
4679f455f9
Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
...
This matches the documented behavior.
Fixes https://github.com/google/jax/issues/8634
PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
9fee130d6b
[JAX] Update users of jax.ops.index...
functions, which are deprecated.
...
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.
PiperOrigin-RevId: 404395250
2021-10-19 16:39:09 -07:00
Peter Hawkins
104a46594b
Add DeprecationWarnings to jax.ops.index_... operators.
...
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
Peter Hawkins
6a284ce5ad
Fix incorrect EllipsisType reference for Python 3.10
2021-10-05 16:16:59 -04:00
Jake VanderPlas
c35b2f2485
DOC: move index update API docs to jnp.ndarray.at
...
- Add docstring to abstract property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
2021-10-01 14:06:08 -07:00
Peter Hawkins
867068821e
Drop out-of-bounds indexes in gather.
2021-09-23 10:35:03 -04:00
Peter Hawkins
46288a299e
Fix wrong dtype output from indexing with an empty slice.
...
Fix test failures in lax_numpy_indexing_test for complex64 pow() scatters.
The existing tests actually catch both of these bugs, but only when run with a sufficiently high number of test cases.
2021-09-22 15:24:56 -04:00
Jake VanderPlas
08e1c831ba
Validate shapes for boolean indices
2021-08-03 15:33:44 -07:00
Peter Hawkins
2168483a62
Add x.at[idx].get().
...
This allows the sorted/unique keyword arguments to be passed to indexed gather operations.
2021-07-07 08:51:09 -04:00
Peter Hawkins
1ff12f05b3
Add unique/sorted annotations to gather().
...
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.
Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Peter Hawkins
44c98ad4e8
Improve JVP rule for scatters with non-overlapping indices.
...
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
Jake VanderPlas
cbb7052379
Implement segment_prod, segment_max, segment_min
2021-04-09 12:06:51 -07:00
Jake VanderPlas
48ac77d272
jax.ops.segment_sum: improve input validation
2021-04-08 08:39:58 -07:00
Jake VanderPlas
8e789c7380
Run doctest on all source files except jax2tf
2021-04-05 10:39:59 -07:00
Neil Girdhar
af0988ee7f
Annotate scatter and random
2021-03-24 20:06:17 -04:00
Jake VanderPlas
cfe934c053
Fix some doc build warnings
2021-01-25 14:08:57 -08:00
Peter Hawkins
9ffe009964
Change segment_sum to drop out-of-bounds indices rather than wrap them.
...
This is a breaking change.
2021-01-13 10:54:27 -05:00
Peter Hawkins
3ac809ede3
[JAX] Move jax.util to jax._src_util.
...
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
AdrienCorenflos
3bc34e4f5a
Update scatter.py
...
Looks like the documentation was not up to date
2020-11-05 18:13:44 +02:00
Peter Hawkins
ef57858deb
Move jax.ops implementation into jax._src.ops.
2020-10-17 11:45:28 -04:00