375 Commits

Author SHA1 Message Date
jax authors
414449e142 Merge pull request #26078 from gnecula:debug_info_jaxpr
PiperOrigin-RevId: 723151082
2025-02-04 10:54:26 -08:00
George Necula
d12aead696 [better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
2025-02-04 10:02:35 +02:00
George Necula
32c98b9a76 [better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.

One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
2025-01-30 07:40:05 +02:00
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
Matthew Johnson
9acd4a95b6 improve checkpoint / remat concreteness error with static_argnums 2024-12-18 04:24:54 +00:00
James Martens
310ff7347c Change to internal dead code elimination. Now the functions in dce_rules are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation).
PiperOrigin-RevId: 695789033
2024-11-12 10:37:04 -08:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Dougal Maclaurin
63e7b7d364 Remove some untested dynamic shapes paths (prep work for stackless).
PiperOrigin-RevId: 676529297
2024-09-19 12:59:48 -07:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
George Necula
71ac0bb446 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 16:37:53 +03:00
George Necula
58d6c4c1ec Roll back #16689
PiperOrigin-RevId: 547773322
2023-07-13 06:05:50 -07:00
George Necula
d21a667235 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 09:59:41 +03:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
Peter Hawkins
64e1f5fe3d Revert: custom_vjp symbolic zeros support
PiperOrigin-RevId: 518597609
2023-03-22 09:56:09 -07:00
Roy Frostig
ac7491ced0 custom_vjp symbolic zeros support 2023-03-21 14:14:35 +00:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* add some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Peter Hawkins
8c7ba99f82 Make Tracer types on JaxprTrace more precise.
instantiate_const() must take and return a JaxprTracer.

Teach pytype that the Tracer returned by full_raise() must be an instance of the Tracer type associated with the Trace, using a Generic type.

PiperOrigin-RevId: 516554216
2023-03-14 09:56:21 -07:00
Matthew Johnson
e9fbf3e98a [shard_map] bug fix: extend axis env in partial_eval_custom rule 2023-03-09 22:06:56 -08:00
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Matthew Johnson
c2aa5c5eed attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
2023-03-02 17:23:58 -08:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
Yash Katariya
78c4ed0e7a Add forwarding support to pjit which was introduced as an optimization. The inputs that are forwarded to outputs are pruned from the outputs of a known_jaxpr.
PiperOrigin-RevId: 503559787
2023-01-20 18:04:26 -08:00
Matthew Johnson
8b585302db add pjit partial_eval_jaxpr_custom rule
fix some issues with closed_call's partial_eval_jaxpr_custom rule

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-01-11 09:30:49 -08:00
Lena Martens
caf4f7b3f7 Lift global_axis calculation from lowering in pxla.py to api.py.
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.

Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.

We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.

PiperOrigin-RevId: 499877003
2023-01-05 07:54:53 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Roy Frostig
523c6f7a53 [jax] move jax.core to jax._src.core
Re-export roughly all of the same symbols via `jax.core` for now.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 495766963
2022-12-15 20:35:20 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Matthew Johnson
f2f2faa4fa add a basic prototype of piles, behind jax_dynamic_shapes
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
Matthew Johnson
6ebf44a681 make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-28 14:12:17 -07:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map 2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip 2022-10-19 10:29:53 -07:00
Jake VanderPlas
d60ceeadd0 [typing] annotate util.unzip2 & util.unzip3 2022-10-18 09:47:49 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Matthew Johnson
076a7348d0 fix -O / PYTHONOPTIMIZE bug
fixes #12688

I'm not sure how to write test cases for PYTHONOPTIMIZE=1 (without growing our
whole test matrix), so I'm leaving this untested...
2022-10-06 23:15:22 -07:00
Matthew Johnson
1e7ca8f77a fix bug in djax type signature inference logic
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-26 17:48:25 -07:00
Matthew Johnson
b6ef90ffdd fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).

But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Sharad Vikram
b2a5d2c3bb Add partial_eval_custom rule for for_loop 2022-09-06 11:00:26 -07:00