Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
26f2f97805
Document why 'import name as name' is used
2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174
CI: fix flake8 ignore declarations
2022-04-21 13:44:12 -07:00
Peter Hawkins
f51a05a889
Remove jax.ops.index... functions.
...
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Jake VanderPlas
245581411e
Add PEP484-compatible export for jax and its subpackages
2021-09-13 14:08:48 -07:00
Jake VanderPlas
cbb7052379
Implement segment_prod, segment_max, segment_min
2021-04-09 12:06:51 -07:00
Peter Hawkins
ef57858deb
Move jax.ops implementation into jax._src.ops.
2020-10-17 11:45:28 -04:00
Peter Hawkins
aa107cf1f4
Move jax.numpy internals into jax._src.numpy.
2020-10-16 20:35:19 -04:00
Qiao Zhang
35d231990c
Add ceil_of_ratio util and bucket_size TODO.
2020-09-21 14:45:37 -07:00
Qiao Zhang
614acce43c
Change segment_sum to use no bucketing by default.
2020-09-21 09:46:26 -07:00
Qiao Zhang
bbe3a6a9a2
Improve segment_sum stability by k-way summation.
2020-09-16 11:11:39 -07:00
Alvaro
ca1d8f4109
Fixing weird behavior in segment_sum when num_segments is None ( #4034 )
...
Co-authored-by: alvarosg <alvarosg@google.com>
2020-09-11 13:51:42 -04:00
James Bradbury
f574b11499
Support preconditions on scatter indices ( #3147 )
...
* wire through precondition flags to XLA scatter
* use scatter precondition flags in tests
* fix DUS batching rule
* make new arguments kw-only
* onp -> np
* fix jax2tf for new args
* fix more test failures
2020-07-21 23:16:27 -07:00
Jake Vanderplas
fb1717233a
Cleanup: deflake jax.experimental and jax.ops ( #3329 )
2020-06-05 19:00:04 -07:00
Peter Hawkins
2f09e89e72
Update internal aliases to lax_numpy to jnp instead of np. ( #2975 )
2020-05-05 20:41:57 -04:00
Peter Hawkins
714b276b9a
Implement jax.ops.index_mul. ( #2696 )
...
* Implement jax.ops.index_mul.
* Add index_mul to documentation.
* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.
* Fix typo in docstring.
2020-04-13 16:16:34 -04:00
George Necula
c52f32b59d
Removed unused imports ( #2385 )
...
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ( #2117 )
...
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
45a02f39f0
Temporarily remove jit
decorator on gather/scatter ops.
2019-09-16 13:57:07 -07:00
Peter Hawkins
5ffddc182e
JIT-compile index and index-update expressions.
...
Improves the performance of indexing in op-by-op mode.
2019-09-13 10:37:41 -04:00
Peter Hawkins
549cd23f3a
Remove debug print statement and stale comment.
2019-07-21 16:35:00 -04:00
Peter Hawkins
f25b2f878b
Merge scatter and gather indexing implementations.
2019-07-16 18:55:44 +01:00
Peter Hawkins
7effcf8512
Edit some comments.
2019-07-16 09:49:35 +01:00
Peter Hawkins
0850318a83
Add support for mixing basic and advanced indexing in the same scatter operation.
2019-07-14 11:55:26 -04:00
Peter Hawkins
b45ea2b416
Remove unnecessary reshape from scatter advanced indexing.
2019-07-13 15:07:51 -04:00
Justin Lebar
d5ba04b79e
Add jax.ops.index_min/max.
...
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
Peter Hawkins
e6082d203d
Add segment_sum to the docs and fix its rendering.
...
Minor doc fixes.
2019-05-08 21:18:17 -04:00
Matthew Johnson
3ab52646f2
reviewer-suggested fixes
2019-05-06 14:37:41 -07:00
Matthew Johnson
41a7a9448d
fix up the is_advanced_int_indexer logic
2019-05-06 14:20:24 -07:00
Matthew Johnson
9adfb80625
add advanced indexing support to jax index ops
...
fixes #658
This commit adds advanced indexing support to jax index operations,
namely index_update and index_add, but does *not* add support for mixed
advanced indexing and slicing. That's left as a NotImplementedError.
This commit also added a segment_sum convenience wrapper.
2019-05-06 14:20:24 -07:00
Matthew Johnson
0cf14837c9
make a lax package, revert control flow names ( #607 )
...
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
QBatista
eb4945c011
DOC: Fix typo in ops.index_update
2019-03-30 17:18:53 +09:00
Peter Hawkins
854e3b1500
Add missing jax/ops/ to jax/BUILD.
...
Fix typo in ops/scatter.py
2019-03-06 09:07:15 -05:00
Peter Hawkins
8b5e09f10a
Add new functions jax.ops.index_add
and jax.ops.index_update
for NumPy-style indexed updates.
...
Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.
Progress on issue #101 . Fixes #122 .
Reenable some disabled TPU indexing tests that now pass.
2019-03-04 15:13:14 -05:00