26 Commits

Author SHA1 Message Date
Yash Katariya
6e00b5e02d [NFC] Rename standard_insert_pbroadcast to standard_insert_pvary
PiperOrigin-RevId: 747943230
2025-04-15 11:02:45 -07:00
Yash Katariya
25c106d132 Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast for unary ops though)
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py

PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dougal Maclaurin
3b89a2e573 Add a utility function to create a tangent zero value from a primal value.
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Sergei Lebedev
3e1c2b3ee9 Removed dead code from add_jaxvals
PiperOrigin-RevId: 672103395
2024-09-07 11:26:33 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
d5646dbac1 partial rollback of #19096 due to internal breakage (relying on jax internals)
PiperOrigin-RevId: 593175212
2023-12-22 15:54:26 -08:00
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable 2023-12-21 17:04:21 -08:00
Matthew Johnson
68635692b5 remove use of cast in ad_util
i hate cast
2023-12-20 21:00:42 -08:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
jax authors
39b14b1b1f Merge pull request #13693 from jakevdp:typing-ad-util
PiperOrigin-RevId: 496783324
2022-12-20 16:50:11 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
8c11caeadd [typing] annotate ad_util 2022-12-16 16:00:38 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
5d3f48204d Add stateful for loop primitives (#10982)
Adds a `get/swap/addupdate` primitive, along with impl, abstract_eval
and jvp rules.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-15 15:55:38 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
e7acb82b14 [remove-units] remove units from api_util.py 2022-04-26 12:31:08 -07:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00