149 Commits

Author SHA1 Message Date
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
George Necula
1e813e1693 [better_errors] Continue adding debug info to Jaxprs (step 4)
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
2025-02-08 09:13:55 +02:00
George Necula
904b74860c [better_errors] Continue adding debug info to Jaxprs (step 3)
This follows after #26078, and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
2025-02-06 16:26:49 +02:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078 , we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Dan Foreman-Mackey
28d573354b Add DCE rules for custom_jvp and custom_vjp. 2025-01-23 15:22:43 -05:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Matthew Johnson
b6482f126e add mutable array ref error checks to cond and custom_vjp 2024-12-20 01:44:50 +00:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Dougal
1c9b23c566 Stop using generators for linear_util transformations.
They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
2024-11-13 13:47:07 -08:00
James Martens
310ff7347c Change to internal dead code elimination. Now the functions in dce_rules are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation).
PiperOrigin-RevId: 695789033
2024-11-12 10:37:04 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Sergei Lebedev
bdf2ca10fc Removed more dead code from various submodules
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
George Necula
5ccfc8d716 Reverts c3b4b76080dbedfebfed978c812338e2f680ee23
PiperOrigin-RevId: 690990311
2024-10-29 06:07:15 -07:00
Matthew Johnson
86a47a7d4e fix jax.custom_gradient to allow closing over non-autodiff tracers 2024-10-29 00:32:01 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Matthew Johnson
7571b9e7f8 custom_vjp: don't drop tangents just because they have a different dtype than the primal
instead, drop them when primal_aval.to_tangent_aval().dtype == float0

TODO: don't do that either. we shouldn't drop the user's output on the floor;
we should require that their rule produce a value of the correct float0 dtype,
or else produce a special symbol that means "zero of whatever type I need" (and
that symbol should probably be a None). but i'm not doing that TODO right now...
2024-09-19 23:31:40 +00:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Dan Foreman-Mackey
dbb34f56dd Raise a clearer error message when closure_converted function is
called with inputs with the wrong structure.

Fixes https://github.com/google/jax/issues/23588
2024-09-17 15:04:09 -04:00
Jake VanderPlas
91f5512965 Document methods of custom_jvp/custom_vjp 2024-08-14 15:37:20 -07:00
Dan Foreman-Mackey
850edee36e Fix bug in custom_vjp with optimize_remat and custom_vmap.
When used with a `custom_vmap` that introduces a new const the previous
implementation of `optimize_remat` would error in its DCE rule because
of unexpected consts when closing the fwd jaxpr. This shouldn't have
ever been hit, but there was a bug in the batching rule for
`remat_opt_p` where we weren't properly converting constvars to invars.
This fixes this bug and should unbreak internal users.
2024-08-13 09:06:57 +01:00
Dan Foreman-Mackey
69fc8bb419 Consolidate handling of input argument resolution in custom_* APIs.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.

Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.

The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.

In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.

PiperOrigin-RevId: 662402158
2024-08-13 00:30:23 -07:00
Dan Foreman-Mackey
efb7721671 Remove unnecessary constraint on keyword-only arguments in custom_vjp with optimize_remat=True.
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
Jake VanderPlas
551f72979c Rollback of #22869
This is causing breakages due to overly-restrictive checks on kwargs

Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be

PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Dan Foreman-Mackey
1d425b2d30 Small tweaks to custom_vmap UI.
I'm working on some extensions to `custom_vmap` and came across these
small UI improvements (I think!). This includes the two changes:

1. A weakening of the `kwargs` check to be consistent with the one in
   `custom_vjp`/`custom_jvp`, and
2. An improved error message when `def_vmap` isn't called.
2024-08-05 17:51:47 +01:00
jax authors
2241dadab6 Merge pull request #22814 from superbobry:maint-2
PiperOrigin-RevId: 658560253
2024-08-01 15:31:40 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Dan Foreman-Mackey
80cfe83ddc Fix issue with multiple arguments when using custom_vjp with optimize_remat.
This was a silly bug in how we were handling the fact that the `fwd`
function expects `bool` entries for symbolic zeros. At least now I've
added a test!
2024-08-01 08:44:35 -04:00
Dan Foreman-Mackey
30d5a78b1c Add optional automatic remat optimization to custom_vjp.
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.

In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.

This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:

1. This feature is currently implemented in "initial style", which means
   that the `fwd` function is traced to a jaxpr when it is initially
   called. This means that when `optimize_remat=True`, the `custom_vjp`
   function doesn't support data dependent conditionals within `fwd`.
   This isn't a fundamental limitation of the method, but this
   implementation is much simpler so it seemed like a good place to
   start, and much of the complexity of the "final style" version of
   this logic should be simplified by work that @dougalm is doing.
   Furthermore, for the immediate use case of opaque calls, initial
   style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
   this isn't a required restriction, but I chose to start without this
   added complexity and we can add support for symbolic zeros as needed
   in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
   currently implement rules for composing with the AD system. This
   means that a `custom_vjp` constructed with `optimize_remat=True`
   won't currently work with some approaches to higher-order AD. I
   expect I know how to fix that and will either include that here or in
   a follow-up.
2024-07-31 10:48:29 -04:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Matthew Johnson
2f4cc6e9cb [custom_vjp] allow bwd rule to produce top-level list (not just tuple) 2024-05-31 21:49:06 +00:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Matthew Johnson
3d4687fbfc add a temporary config option to disable custom_vjp shape checking 2024-04-04 18:21:10 -07:00
Matthew Johnson
6f38f277b9 temporarily relax the cotangent dtype check introduced in #19009
PiperOrigin-RevId: 615883208
2024-03-14 13:22:42 -07:00
Matthew Johnson
1326c74db7 add a shape mismatch check and error to custom_vjp
no idea how we lasted so long without this...
2024-03-13 19:57:09 -07:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Matthew Johnson
bc1e5f0220 [custom_vjp] handle Nones in subtrees returned by bwd rule
fixes #8356
2024-02-21 00:37:04 -08:00
Peter Hawkins
f1ea67117e Split name_stack out of mlir.ModuleContext.
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else (naturally) is the Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
2024-02-20 07:17:23 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Yash Katariya
4c9241ecda Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache.
PiperOrigin-RevId: 593023314
2023-12-21 22:15:52 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00