Yash Katariya
6e00b5e02d
[NFC] Rename standard_insert_pbroadcast
to standard_insert_pvary
...
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Yash Katariya
25c106d132
Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast
for unary ops though)
...
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py
PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Jake VanderPlas
40367a9eaf
Cleanup: remove uses of no-op raise_to_shaped
2024-12-12 09:49:06 -08:00
Dougal Maclaurin
478b750c29
Reverts f281c6f46475270a57a02416469226315377592c
...
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464
Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
...
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7
Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
...
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
3b89a2e573
Add a utility function to create a tangent zero value from a primal value.
...
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dougal Maclaurin
018189491b
Clean up and fix primal type to tangent type mapping
...
This is part of the ["stackless"](#23299 ) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
3. Add `to_tangent_type` calls in various other places they're missing.
4. Remove non-support for float0 in custom deriviatives?
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Sergei Lebedev
3e1c2b3ee9
Removed dead code from add_jaxvals
...
PiperOrigin-RevId: 672103395
2024-09-07 11:26:33 -07:00
Peter Hawkins
7f4ef63cd8
Run pyupgrade --py310-plus
.
...
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
d5646dbac1
partial rollback of #19096 due to internal breakage (relying on jax internals)
...
PiperOrigin-RevId: 593175212
2023-12-22 15:54:26 -08:00
Matthew Johnson
be3ca507db
del add_any_p and zeros_like_p, replace aval-dispatched traceable
2023-12-21 17:04:21 -08:00
Matthew Johnson
68635692b5
remove use of cast in ad_util
...
i hate cast
2023-12-20 21:00:42 -08:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
5c4525cb10
custom_jvp symbolic zeros support
...
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
jax authors
39b14b1b1f
Merge pull request #13693 from jakevdp:typing-ad-util
...
PiperOrigin-RevId: 496783324
2022-12-20 16:50:11 -08:00
Roy Frostig
d927a5dbf3
migrate internal dependencies from jax.core
to jax._src.core
...
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
8c11caeadd
[typing] annotate ad_util
2022-12-16 16:00:38 -08:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
5d3f48204d
Add stateful for loop primitives ( #10982 )
...
Adds a `get/swap/addupdate` primitive, along with impl, abstract_eval
and jvp rules.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-15 15:55:38 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Matthew Johnson
9cd55a2bbd
[remove-units] remove units
2022-05-04 10:58:56 -07:00
Matthew Johnson
e7acb82b14
[remove-units] remove units from api_util.py
2022-04-26 12:31:08 -07:00
Peter Hawkins
e9611eb090
Move jax.ad_util to jax._src.ad_util.
...
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.
PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00