small autodidax tweaks

This commit is contained in:
Matthew Johnson 2021-08-05 04:51:24 -07:00
parent c75f77362d
commit 24de3e992c
3 changed files with 240 additions and 109 deletions

View File

@ -1766,6 +1766,7 @@
},
"outputs": [],
"source": [
"from typing import DefaultDict\n",
"from collections import defaultdict\n",
"import string\n",
"\n",
@ -1800,7 +1801,7 @@
"def vcat(ps: List[PPrint]) -> PPrint:\n",
" return sum(ps, pp(''))\n",
"\n",
"def pp_jaxpr(jaxpr: Jaxpr):\n",
"def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:\n",
" namegen = (''.join(s) for r in it.count(1)\n",
" for s in it.permutations(string.ascii_lowercase, r))\n",
" names = defaultdict(lambda: next(namegen))\n",
@ -1811,15 +1812,19 @@
" return (pp(f'{{ lambda {in_binders} .') +\n",
" ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n",
"\n",
"def var_str(names: Dict[Var, str], v: Var) -> str:\n",
"def var_str(names: DefaultDict[Var, str], v: Var) -> str:\n",
" return f'{names[v]}:{v.aval.str_short()}'\n",
"\n",
"def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return lhs >> pp(' = ') >> rhs\n",
"def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" rule = pp_rules.get(eqn.primitive)\n",
" if rule:\n",
" return rule(names, eqn)\n",
" else:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return lhs >> pp(' = ') >> rhs\n",
"\n",
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
" items = sorted(params.items())\n",
@ -1828,7 +1833,8 @@
" else:\n",
" return pp(' ')\n",
"\n",
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))"
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))\n",
"pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}"
]
},
{
@ -2167,7 +2173,7 @@
" [bool, int, float, np.ndarray, np.float64, np.float32]}\n",
"\n",
"def handle_result(aval: ShapedArray, buf):\n",
" del aval # Unused for now.\n",
" del aval # Unused for now\n",
" return buf.to_py()\n",
"\n",
"xla_translations = {}"
@ -2332,7 +2338,7 @@
"outputs": [],
"source": [
"def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
" outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n",
" num_consts=len(new_consts))\n",
@ -2362,7 +2368,7 @@
"outputs": [],
"source": [
"def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
" outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
" num_consts=len(new_consts))\n",
@ -2397,7 +2403,7 @@
"outputs": [],
"source": [
"def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" jaxpr_type = typecheck_jaxpr(jaxpr)\n",
" if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
" raise TypeError\n",
@ -2532,6 +2538,29 @@
"print(ydot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_end_of_cell_marker": 0,
"lines_to_next_cell": 1,
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return vcat([lhs >> pp(' = ') >> rhs,\n",
" pp_jaxpr(eqn.params['jaxpr']).indent(2)])\n",
"pp_rules[xla_call_p] = pprint_xla_call"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -2729,11 +2758,11 @@
"operations out of Python first before sorting out what can be evaluated now\n",
"and what must be delayed, we want only to form a jaxpr for those operations\n",
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
"of automatic differentiation, this is the feature that ultimately enables us to\n",
"handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
"flow works because partial evaluation keeps the primal computation in Python.\n",
"As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
"what can be evaluated and what must be staged out into a jaxpr.\n",
"of automatic differentiation, this is the feature that ultimately enables us\n",
"to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python\n",
"control flow works because partial evaluation keeps the primal computation in\n",
"Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly\n",
"sort out what can be evaluated and what must be staged out into a jaxpr.\n",
"\n",
"First, we start with a `PartialVal` class, which represents a value that can\n",
"be either known or unknown:"
@ -2803,8 +2832,9 @@
"do so, it builds a bipartite directed acyclic graph (DAG) between\n",
"`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
"nodes, representing formulas for how to compute some values from others. One\n",
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
"application, but we also have recipe types for constants and lambda binders:"
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s\n",
"primitive application, but we also have recipe types for constants and lambda\n",
"binders:"
]
},
{
@ -2945,11 +2975,12 @@
"source": [
"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
" tracers_out: List[PartialEvalTracer]):\n",
" tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}\n",
" constvar_to_val = {}\n",
" constid_to_var = {}\n",
" processed_eqns = set()\n",
" eqns = []\n",
" tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))\n",
" for t in tracers_in}\n",
" constvar_to_val: Dict[int, Any] = {}\n",
" constid_to_var: Dict[int, Var] = {}\n",
" processed_eqns: Set[int] = set()\n",
" eqns: List[JaxprEqn] = []\n",
" for t in toposort(tracers_out, tracer_parents):\n",
" if isinstance(t.recipe, LambdaBindingRecipe):\n",
" assert id(t) in set(map(id, tracers_in))\n",
@ -3083,7 +3114,7 @@
"outputs": [],
"source": [
"def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" in_unknowns = [not t.pval.is_known for t in tracers]\n",
" jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n",
" known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n",
@ -3106,8 +3137,8 @@
" env: Dict[Var, bool] = {}\n",
" residuals: Set[Var] = set()\n",
"\n",
" def read(v: Atom) -> bool:\n",
" return type(v) is Var and env[v]\n",
" def read(x: Atom) -> bool:\n",
" return type(x) is Var and env[x]\n",
"\n",
" def write(unk: bool, v: Var) -> None:\n",
" env[v] = unk\n",
@ -3139,6 +3170,7 @@
" out_unknowns = map(op.or_, out_unknowns, instantiate)\n",
"\n",
" residuals, num_res = list(residuals), len(residuals)\n",
" assert all(type(v) is Var for v in residuals), residuals\n",
"\n",
" ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n",
" outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n",
@ -3170,16 +3202,16 @@
"partial_eval_jaxpr_rules = {}\n",
"\n",
"def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n",
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:\n",
" jaxpr = eqn.params['jaxpr']\n",
" jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n",
" ins1, ins2 = partition_list(unks_in, eqn.inputs)\n",
" outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n",
" residuals, _ = split_list(jaxpr2.in_binders, num_res)\n",
" out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)\n",
" residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]\n",
" eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n",
" outs1 + residuals)\n",
" out_binders1 + residuals)\n",
" eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n",
" dict(jaxpr=jaxpr2, num_consts=0), outs2)\n",
" dict(jaxpr=jaxpr2, num_consts=0), out_binders2)\n",
" return eqn1, eqn2, unks_out, residuals\n",
"partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn"
]
@ -3395,7 +3427,7 @@
"transpose_rules[add_p] = add_transpose_rule\n",
"\n",
"def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
" del num_consts # Unused.\n",
" del num_consts # Unused\n",
" undef_primals = [type(x) is UndefPrimal for x in invals]\n",
" transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n",
" residuals, _ = partition_list(undef_primals, invals)\n",
@ -3804,7 +3836,7 @@
"abstract_eval_rules[cond_p] = cond_abstract_eval\n",
"\n",
"def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n",
" del in_avals # Unused.\n",
" del in_avals # Unused\n",
" pred, *in_vals = in_vals\n",
" flat_vals, in_tree = tree_flatten(in_vals)\n",
" operand = xops.Tuple(c, flat_vals)\n",
@ -3857,6 +3889,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_end_of_cell_marker": 0,
"lines_to_next_cell": 1
},
"outputs": [],
@ -3954,7 +3987,8 @@
" eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],\n",
" dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
" outs2)\n",
" return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]\n",
" res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals\n",
" return eqn1, eqn2, unks_out, res\n",
"partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn"
]
},
@ -4002,12 +4036,37 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
" true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
" new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}\n",
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
" rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>\n",
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
" for x in eqn.inputs)))\n",
" return vcat([lhs >> pp(' = ') >> rhs,\n",
" pp_jaxpr(true_jaxpr).indent(2),\n",
" pp_jaxpr(false_jaxpr).indent(2)])\n",
"pp_rules[cond_p] = pprint_cond"
]
}
],
"metadata": {

View File

@ -1304,6 +1304,7 @@ def make_jaxpr_v1(f, *avals_in):
```{code-cell}
:tags: [hide-input]
from typing import DefaultDict
from collections import defaultdict
import string
@ -1338,7 +1339,7 @@ def pp(s: Any) -> PPrint:
def vcat(ps: List[PPrint]) -> PPrint:
return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr):
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
namegen = (''.join(s) for r in it.count(1)
for s in it.permutations(string.ascii_lowercase, r))
names = defaultdict(lambda: next(namegen))
@ -1349,15 +1350,19 @@ def pp_jaxpr(jaxpr: Jaxpr):
return (pp(f'{{ lambda {in_binders} .') +
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: Dict[Var, str], v: Var) -> str:
def var_str(names: DefaultDict[Var, str], v: Var) -> str:
return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
rule = pp_rules.get(eqn.primitive)
if rule:
return rule(names, eqn)
else:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_params(params: Dict[str, Any]) -> PPrint:
items = sorted(params.items())
@ -1367,6 +1372,7 @@ def pp_params(params: Dict[str, Any]) -> PPrint:
return pp(' ')
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}
```
```{code-cell}
@ -1611,7 +1617,7 @@ input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now.
del aval # Unused for now
return buf.to_py()
xla_translations = {}
@ -1709,7 +1715,7 @@ level." Let's fix that!
```{code-cell}
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
num_consts=len(new_consts))
@ -1732,7 +1738,7 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
```{code-cell}
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
@ -1760,7 +1766,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
```{code-cell}
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
jaxpr_type = typecheck_jaxpr(jaxpr)
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
@ -1857,6 +1863,20 @@ print(y)
print(ydot)
```
```{code-cell}
:tags: [hide-input]
def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
```
## Part 4: `linearize` and `vjp` (and `grad`!)
The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
@ -2021,11 +2041,11 @@ forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
operations out of Python first before sorting out what can be evaluated now
and what must be delayed, we want only to form a jaxpr for those operations
that _must_ be delayed due to a dependence on unknown inputs. In the context
of automatic differentiation, this is the feature that ultimately enables us to
handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
flow works because partial evaluation keeps the primal computation in Python.
As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
what can be evaluated and what must be staged out into a jaxpr.
of automatic differentiation, this is the feature that ultimately enables us
to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python
control flow works because partial evaluation keeps the primal computation in
Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly
sort out what can be evaluated and what must be staged out into a jaxpr.
First, we start with a `PartialVal` class, which represents a value that can
be either known or unknown:
@ -2071,8 +2091,9 @@ interpreter will build a jaxpr on the fly while tracking data dependencies. To
do so, it builds a bipartite directed acyclic graph (DAG) between
`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
nodes, representing formulas for how to compute some values from others. One
kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
application, but we also have recipe types for constants and lambda binders:
kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
primitive application, but we also have recipe types for constants and lambda
binders:
```{code-cell}
from weakref import ref, ReferenceType
@ -2172,11 +2193,12 @@ The jaxpr corresponds to a topological sort of the graph.
```{code-cell}
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
tracers_out: List[PartialEvalTracer]):
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
constvar_to_val = {}
constid_to_var = {}
processed_eqns = set()
eqns = []
tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
for t in tracers_in}
constvar_to_val: Dict[int, Any] = {}
constid_to_var: Dict[int, Var] = {}
processed_eqns: Set[int] = set()
eqns: List[JaxprEqn] = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
@ -2276,7 +2298,7 @@ jaxprs, which we'll call `xla_call_peval_eqn`.
```{code-cell}
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
in_unknowns = [not t.pval.is_known for t in tracers]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
@ -2299,8 +2321,8 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
env: Dict[Var, bool] = {}
residuals: Set[Var] = set()
def read(v: Atom) -> bool:
return type(v) is Var and env[v]
def read(x: Atom) -> bool:
return type(x) is Var and env[x]
def write(unk: bool, v: Var) -> None:
env[v] = unk
@ -2332,6 +2354,7 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
assert all(type(v) is Var for v in residuals), residuals
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
@ -2363,16 +2386,16 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
ins1, ins2 = partition_list(unks_in, eqn.inputs)
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(jaxpr2.in_binders, num_res)
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
outs1 + residuals)
out_binders1 + residuals)
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
dict(jaxpr=jaxpr2, num_consts=0), outs2)
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
```
@ -2535,7 +2558,7 @@ def add_transpose_rule(cts, x, y):
transpose_rules[add_p] = add_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
undef_primals = [type(x) is UndefPrimal for x in invals]
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
residuals, _ = partition_list(undef_primals, invals)
@ -2813,7 +2836,7 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused.
del in_avals # Unused
pred, *in_vals = in_vals
flat_vals, in_tree = tree_flatten(in_vals)
operand = xops.Tuple(c, flat_vals)
@ -2930,7 +2953,8 @@ def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
```
@ -2961,3 +2985,19 @@ transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
```
```{code-cell}
:tags: [hide-input]
def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(true_jaxpr).indent(2),
pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond
```

View File

@ -1254,6 +1254,7 @@ def make_jaxpr_v1(f, *avals_in):
return jaxpr, consts, out_tree()
# + tags=["hide-input"]
from typing import DefaultDict
from collections import defaultdict
import string
@ -1288,7 +1289,7 @@ def pp(s: Any) -> PPrint:
def vcat(ps: List[PPrint]) -> PPrint:
return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr):
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
namegen = (''.join(s) for r in it.count(1)
for s in it.permutations(string.ascii_lowercase, r))
names = defaultdict(lambda: next(namegen))
@ -1299,15 +1300,19 @@ def pp_jaxpr(jaxpr: Jaxpr):
return (pp(f'{{ lambda {in_binders} .') +
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: Dict[Var, str], v: Var) -> str:
def var_str(names: DefaultDict[Var, str], v: Var) -> str:
return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
rule = pp_rules.get(eqn.primitive)
if rule:
return rule(names, eqn)
else:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_params(params: Dict[str, Any]) -> PPrint:
items = sorted(params.items())
@ -1317,6 +1322,7 @@ def pp_params(params: Dict[str, Any]) -> PPrint:
return pp(' ')
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}
# -
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
@ -1548,7 +1554,7 @@ input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now.
del aval # Unused for now
return buf.to_py()
xla_translations = {}
@ -1637,7 +1643,7 @@ print(jit(deriv(deriv(f)))(3.))
# +
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
num_consts=len(new_consts))
@ -1659,7 +1665,7 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
# +
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
@ -1686,7 +1692,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
# +
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
jaxpr_type = typecheck_jaxpr(jaxpr)
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
@ -1775,6 +1781,17 @@ x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
# + tags=["hide-input"]
def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
# -
# ## Part 4: `linearize` and `vjp` (and `grad`!)
@ -1939,11 +1956,11 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# operations out of Python first before sorting out what can be evaluated now
# and what must be delayed, we want only to form a jaxpr for those operations
# that _must_ be delayed due to a dependence on unknown inputs. In the context
# of automatic differentiation, this is the feature that ultimately enables us to
# handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
# flow works because partial evaluation keeps the primal computation in Python.
# As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
# what can be evaluated and what must be staged out into a jaxpr.
# of automatic differentiation, this is the feature that ultimately enables us
# to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python
# control flow works because partial evaluation keeps the primal computation in
# Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly
# sort out what can be evaluated and what must be staged out into a jaxpr.
#
# First, we start with a `PartialVal` class, which represents a value that can
# be either known or unknown:
@ -1985,8 +2002,9 @@ def partial_eval_flat(f: Callable, pvals_in: List[PartialVal]
# do so, it builds a bipartite directed acyclic graph (DAG) between
# `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
# nodes, representing formulas for how to compute some values from others. One
# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
# application, but we also have recipe types for constants and lambda binders:
# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
# primitive application, but we also have recipe types for constants and lambda
# binders:
# +
from weakref import ref, ReferenceType
@ -2086,11 +2104,12 @@ partial_eval_rules = {}
# +
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
tracers_out: List[PartialEvalTracer]):
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
constvar_to_val = {}
constid_to_var = {}
processed_eqns = set()
eqns = []
tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
for t in tracers_in}
constvar_to_val: Dict[int, Any] = {}
constid_to_var: Dict[int, Var] = {}
processed_eqns: Set[int] = set()
eqns: List[JaxprEqn] = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
@ -2185,7 +2204,7 @@ print(sin_lin(1.), cos(3.))
# +
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
in_unknowns = [not t.pval.is_known for t in tracers]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
@ -2208,8 +2227,8 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
env: Dict[Var, bool] = {}
residuals: Set[Var] = set()
def read(v: Atom) -> bool:
return type(v) is Var and env[v]
def read(x: Atom) -> bool:
return type(x) is Var and env[x]
def write(unk: bool, v: Var) -> None:
env[v] = unk
@ -2241,6 +2260,7 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
assert all(type(v) is Var for v in residuals), residuals
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
@ -2272,16 +2292,16 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
ins1, ins2 = partition_list(unks_in, eqn.inputs)
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(jaxpr2.in_binders, num_res)
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
outs1 + residuals)
out_binders1 + residuals)
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
dict(jaxpr=jaxpr2, num_consts=0), outs2)
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
# -
@ -2442,7 +2462,7 @@ def add_transpose_rule(cts, x, y):
transpose_rules[add_p] = add_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
del num_consts # Unused.
del num_consts # Unused
undef_primals = [type(x) is UndefPrimal for x in invals]
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
residuals, _ = partition_list(undef_primals, invals)
@ -2701,7 +2721,7 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused.
del in_avals # Unused
pred, *in_vals = in_vals
flat_vals, in_tree = tree_flatten(in_vals)
operand = xops.Tuple(c, flat_vals)
@ -2791,8 +2811,6 @@ def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
return new_jaxpr1, new_jaxpr2
# -
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
@ -2815,7 +2833,8 @@ def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
@ -2839,3 +2858,16 @@ transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
# + tags=["hide-input"]
def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(true_jaxpr).indent(2),
pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond