35 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
f51a05a889 Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Jake VanderPlas
cbb7052379 Implement segment_prod, segment_max, segment_min 2021-04-09 12:06:51 -07:00
Peter Hawkins
ef57858deb Move jax.ops implementation into jax._src.ops. 2020-10-17 11:45:28 -04:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00
Qiao Zhang
35d231990c Add ceil_of_ratio util and bucket_size TODO. 2020-09-21 14:45:37 -07:00
Qiao Zhang
614acce43c Change segment_sum to use no bucketing by default. 2020-09-21 09:46:26 -07:00
Qiao Zhang
bbe3a6a9a2 Improve segment_sum stability by k-way summation. 2020-09-16 11:11:39 -07:00
Alvaro
ca1d8f4109
Fixing weird behavior in segment_sum when num_segments is None (#4034)
Co-authored-by: alvarosg <alvarosg@google.com>
2020-09-11 13:51:42 -04:00
James Bradbury
f574b11499
Support preconditions on scatter indices (#3147)
* wire through precondition flags to XLA scatter

* use scatter precondition flags in tests

* fix DUS batching rule

* make new arguments kw-only

* onp -> np

* fix jax2tf for new args

* fix more test failures
2020-07-21 23:16:27 -07:00
Jake Vanderplas
fb1717233a
Cleanup: deflake jax.experimental and jax.ops (#3329) 2020-06-05 19:00:04 -07:00
Peter Hawkins
2f09e89e72
Update internal aliases to lax_numpy to jnp instead of np. (#2975) 2020-05-05 20:41:57 -04:00
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. (#2696)
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
2020-04-13 16:16:34 -04:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
45a02f39f0 Temporarily remove jit decorator on gather/scatter ops. 2019-09-16 13:57:07 -07:00
Peter Hawkins
5ffddc182e JIT-compile index and index-update expressions.
Improves the performance of indexing in op-by-op mode.
2019-09-13 10:37:41 -04:00
Peter Hawkins
549cd23f3a Remove debug print statement and stale comment. 2019-07-21 16:35:00 -04:00
Peter Hawkins
f25b2f878b Merge scatter and gather indexing implementations. 2019-07-16 18:55:44 +01:00
Peter Hawkins
7effcf8512 Edit some comments. 2019-07-16 09:49:35 +01:00
Peter Hawkins
0850318a83 Add support for mixing basic and advanced indexing in the same scatter operation. 2019-07-14 11:55:26 -04:00
Peter Hawkins
b45ea2b416 Remove unnecessary reshape from scatter advanced indexing. 2019-07-13 15:07:51 -04:00
Justin Lebar
d5ba04b79e Add jax.ops.index_min/max.
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
Peter Hawkins
e6082d203d Add segment_sum to the docs and fix its rendering.
Minor doc fixes.
2019-05-08 21:18:17 -04:00
Matthew Johnson
3ab52646f2 reviewer-suggested fixes 2019-05-06 14:37:41 -07:00
Matthew Johnson
41a7a9448d fix up the is_advanced_int_indexer logic 2019-05-06 14:20:24 -07:00
Matthew Johnson
9adfb80625 add advanced indexing support to jax index ops
fixes #658

This commit adds advanced indexing support to jax index operations,
namely index_update and index_add, but does *not* add support for mixed
advanced indexing and slicing. That's left as a NotImplementedError.

This commit also added a segment_sum convenience wrapper.
2019-05-06 14:20:24 -07:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
QBatista
eb4945c011 DOC: Fix typo in ops.index_update 2019-03-30 17:18:53 +09:00
Peter Hawkins
854e3b1500 Add missing jax/ops/ to jax/BUILD.
Fix typo in ops/scatter.py
2019-03-06 09:07:15 -05:00
Peter Hawkins
8b5e09f10a Add new functions jax.ops.index_add and jax.ops.index_update for NumPy-style indexed updates.
Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.

Progress on issue #101. Fixes #122.

Reenable some disabled TPU indexing tests that now pass.
2019-03-04 15:13:14 -05:00