rename and simplify TypedJaxpr -> ClosedJaxpr
This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
in_avals / out_avals no longer need to be constructed), making them
easier to work with;
* correspondingly rules out a class of errors (mismatches between
invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
but they're closed in that they are packaged with their constant
values).
This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)
Co-authored-by: George Necula <gcnecula@gmail.com>
* applied simple find+sed for 'master' -> 'main'
* Rename master->main in JAX API and internals (#4178)
* Started with #4174
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master
Co-authored-by: George Necula <gcnecula@gmail.com>
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.
* block-unrolled scan implementation, via optional `_unroll` scan parameter
* index statically in the inlined path of lax.scan
* make `unroll` a required scan parameter, and test that it unrolls
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.
We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)
This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.
However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.
Co-authored-by: Roy Frostig <frostig@google.com>