Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.
This commit
- Changes the intended representation of data for piles and the
corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
are irrelevant to the padded representation.
This was called shape_poly.compute_dim_values. We rename it to
shape_poly.unify_avals_with_args and we add better error reporting to it.
Now it will identify the arg/kwarg where there is a shape discrepancy.
This is intended to be a pure refactoring, in preparation for adding
support for shape polymorphism to jax_export.call_exported.
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
This change re-introduces symbolic zero support for `custom_vjp`.
This time:
* The forward rule API is slightly different, accepting two-field
records at pytree leaves rather than pairs.
* In the default setting where symbolic_zeros is not set, there are no
new requirements from pytree node definitions that are involved in
the primal arguments. This avoids any change in behavior on the
default path. In particular, custom pytree node definitions that
aren't completely polymorphic in unflattening can remain as is.
* There is an additional test involving a custom pytree node.
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).
We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).
This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!
It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
I was looking at some profiles and noticed canonicalize_shape showing up as a noticeable
overhead in certain cases. Which makes sense, given that we carefully check all possible
cases before trying to consider integers as plausible elements (which are the most popular
_by far_). And this function is pretty hot, because it gets called any time we create a new
`ShapedArray`.
I wrote a small benchmark that repeatedly calls canonicalize_shape on a 4-sized tuple of
integers.
Before:
7.62µs ± 8%
After:
1.42µs ± 2%
So a pretty easy 5x improvement overall. And in more real cases, when resharding an array
onto 8 TPUs, 50% of the time was spent on creating shapes for avals of device buffers.
PiperOrigin-RevId: 516795311
instantiate_const() must take and return a JaxprTracer.
Teach pytype that the Tracer returned by full_raise() must be an instance of the Tracer type associated with the Trace, using a Generic type.
PiperOrigin-RevId: 516554216
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
I noticed some slightly-too-general type annotations in core.py. By tightening
them we could simplify the code too. (I think these were leftovers from
pre-omnistaging...)