mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
small autodidax tweaks
This commit is contained in:
parent
c75f77362d
commit
24de3e992c
@ -1766,6 +1766,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import DefaultDict\n",
|
||||
"from collections import defaultdict\n",
|
||||
"import string\n",
|
||||
"\n",
|
||||
@ -1800,7 +1801,7 @@
|
||||
"def vcat(ps: List[PPrint]) -> PPrint:\n",
|
||||
" return sum(ps, pp(''))\n",
|
||||
"\n",
|
||||
"def pp_jaxpr(jaxpr: Jaxpr):\n",
|
||||
"def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:\n",
|
||||
" namegen = (''.join(s) for r in it.count(1)\n",
|
||||
" for s in it.permutations(string.ascii_lowercase, r))\n",
|
||||
" names = defaultdict(lambda: next(namegen))\n",
|
||||
@ -1811,15 +1812,19 @@
|
||||
" return (pp(f'{{ lambda {in_binders} .') +\n",
|
||||
" ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n",
|
||||
"\n",
|
||||
"def var_str(names: Dict[Var, str], v: Var) -> str:\n",
|
||||
"def var_str(names: DefaultDict[Var, str], v: Var) -> str:\n",
|
||||
" return f'{names[v]}:{v.aval.str_short()}'\n",
|
||||
"\n",
|
||||
"def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
||||
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
||||
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
|
||||
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
||||
" for x in eqn.inputs)))\n",
|
||||
" return lhs >> pp(' = ') >> rhs\n",
|
||||
"def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
||||
" rule = pp_rules.get(eqn.primitive)\n",
|
||||
" if rule:\n",
|
||||
" return rule(names, eqn)\n",
|
||||
" else:\n",
|
||||
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
||||
" rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
|
||||
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
||||
" for x in eqn.inputs)))\n",
|
||||
" return lhs >> pp(' = ') >> rhs\n",
|
||||
"\n",
|
||||
"def pp_params(params: Dict[str, Any]) -> PPrint:\n",
|
||||
" items = sorted(params.items())\n",
|
||||
@ -1828,7 +1833,8 @@
|
||||
" else:\n",
|
||||
" return pp(' ')\n",
|
||||
"\n",
|
||||
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))"
|
||||
"Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))\n",
|
||||
"pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2167,7 +2173,7 @@
|
||||
" [bool, int, float, np.ndarray, np.float64, np.float32]}\n",
|
||||
"\n",
|
||||
"def handle_result(aval: ShapedArray, buf):\n",
|
||||
" del aval # Unused for now.\n",
|
||||
" del aval # Unused for now\n",
|
||||
" return buf.to_py()\n",
|
||||
"\n",
|
||||
"xla_translations = {}"
|
||||
@ -2332,7 +2338,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n",
|
||||
" del num_consts # Unused.\n",
|
||||
" del num_consts # Unused\n",
|
||||
" new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
|
||||
" outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n",
|
||||
" num_consts=len(new_consts))\n",
|
||||
@ -2362,7 +2368,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
|
||||
" del num_consts # Unused.\n",
|
||||
" del num_consts # Unused\n",
|
||||
" new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
|
||||
" outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
|
||||
" num_consts=len(new_consts))\n",
|
||||
@ -2397,7 +2403,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n",
|
||||
" del num_consts # Unused.\n",
|
||||
" del num_consts # Unused\n",
|
||||
" jaxpr_type = typecheck_jaxpr(jaxpr)\n",
|
||||
" if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
|
||||
" raise TypeError\n",
|
||||
@ -2532,6 +2538,29 @@
|
||||
"print(ydot)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_end_of_cell_marker": 0,
|
||||
"lines_to_next_cell": 1,
|
||||
"tags": [
|
||||
"hide-input"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
||||
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
||||
" params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}\n",
|
||||
" rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>\n",
|
||||
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
||||
" for x in eqn.inputs)))\n",
|
||||
" return vcat([lhs >> pp(' = ') >> rhs,\n",
|
||||
" pp_jaxpr(eqn.params['jaxpr']).indent(2)])\n",
|
||||
"pp_rules[xla_call_p] = pprint_xla_call"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -2729,11 +2758,11 @@
|
||||
"operations out of Python first before sorting out what can be evaluated now\n",
|
||||
"and what must be delayed, we want only to form a jaxpr for those operations\n",
|
||||
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
|
||||
"of automatic differentiation, this is the feature that ultimately enables us to\n",
|
||||
"handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
|
||||
"flow works because partial evaluation keeps the primal computation in Python.\n",
|
||||
"As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
|
||||
"what can be evaluated and what must be staged out into a jaxpr.\n",
|
||||
"of automatic differentiation, this is the feature that ultimately enables us\n",
|
||||
"to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python\n",
|
||||
"control flow works because partial evaluation keeps the primal computation in\n",
|
||||
"Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly\n",
|
||||
"sort out what can be evaluated and what must be staged out into a jaxpr.\n",
|
||||
"\n",
|
||||
"First, we start with a `PartialVal` class, which represents a value that can\n",
|
||||
"be either known or unknown:"
|
||||
@ -2803,8 +2832,9 @@
|
||||
"do so, it builds a bipartite directed acyclic graph (DAG) between\n",
|
||||
"`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
|
||||
"nodes, representing formulas for how to compute some values from others. One\n",
|
||||
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
|
||||
"application, but we also have recipe types for constants and lambda binders:"
|
||||
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s\n",
|
||||
"primitive application, but we also have recipe types for constants and lambda\n",
|
||||
"binders:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2945,11 +2975,12 @@
|
||||
"source": [
|
||||
"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
|
||||
" tracers_out: List[PartialEvalTracer]):\n",
|
||||
" tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}\n",
|
||||
" constvar_to_val = {}\n",
|
||||
" constid_to_var = {}\n",
|
||||
" processed_eqns = set()\n",
|
||||
" eqns = []\n",
|
||||
" tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))\n",
|
||||
" for t in tracers_in}\n",
|
||||
" constvar_to_val: Dict[int, Any] = {}\n",
|
||||
" constid_to_var: Dict[int, Var] = {}\n",
|
||||
" processed_eqns: Set[int] = set()\n",
|
||||
" eqns: List[JaxprEqn] = []\n",
|
||||
" for t in toposort(tracers_out, tracer_parents):\n",
|
||||
" if isinstance(t.recipe, LambdaBindingRecipe):\n",
|
||||
" assert id(t) in set(map(id, tracers_in))\n",
|
||||
@ -3083,7 +3114,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n",
|
||||
" del num_consts # Unused.\n",
|
||||
" del num_consts # Unused\n",
|
||||
" in_unknowns = [not t.pval.is_known for t in tracers]\n",
|
||||
" jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n",
|
||||
" known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n",
|
||||
@ -3106,8 +3137,8 @@
|
||||
" env: Dict[Var, bool] = {}\n",
|
||||
" residuals: Set[Var] = set()\n",
|
||||
"\n",
|
||||
" def read(v: Atom) -> bool:\n",
|
||||
" return type(v) is Var and env[v]\n",
|
||||
" def read(x: Atom) -> bool:\n",
|
||||
" return type(x) is Var and env[x]\n",
|
||||
"\n",
|
||||
" def write(unk: bool, v: Var) -> None:\n",
|
||||
" env[v] = unk\n",
|
||||
@ -3139,6 +3170,7 @@
|
||||
" out_unknowns = map(op.or_, out_unknowns, instantiate)\n",
|
||||
"\n",
|
||||
" residuals, num_res = list(residuals), len(residuals)\n",
|
||||
" assert all(type(v) is Var for v in residuals), residuals\n",
|
||||
"\n",
|
||||
" ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n",
|
||||
" outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n",
|
||||
@ -3170,16 +3202,16 @@
|
||||
"partial_eval_jaxpr_rules = {}\n",
|
||||
"\n",
|
||||
"def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
|
||||
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n",
|
||||
" ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:\n",
|
||||
" jaxpr = eqn.params['jaxpr']\n",
|
||||
" jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n",
|
||||
" ins1, ins2 = partition_list(unks_in, eqn.inputs)\n",
|
||||
" outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n",
|
||||
" residuals, _ = split_list(jaxpr2.in_binders, num_res)\n",
|
||||
" out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)\n",
|
||||
" residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]\n",
|
||||
" eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n",
|
||||
" outs1 + residuals)\n",
|
||||
" out_binders1 + residuals)\n",
|
||||
" eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n",
|
||||
" dict(jaxpr=jaxpr2, num_consts=0), outs2)\n",
|
||||
" dict(jaxpr=jaxpr2, num_consts=0), out_binders2)\n",
|
||||
" return eqn1, eqn2, unks_out, residuals\n",
|
||||
"partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn"
|
||||
]
|
||||
@ -3395,7 +3427,7 @@
|
||||
"transpose_rules[add_p] = add_transpose_rule\n",
|
||||
"\n",
|
||||
"def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
|
||||
" del num_consts # Unused.\n",
|
||||
" del num_consts # Unused\n",
|
||||
" undef_primals = [type(x) is UndefPrimal for x in invals]\n",
|
||||
" transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n",
|
||||
" residuals, _ = partition_list(undef_primals, invals)\n",
|
||||
@ -3804,7 +3836,7 @@
|
||||
"abstract_eval_rules[cond_p] = cond_abstract_eval\n",
|
||||
"\n",
|
||||
"def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n",
|
||||
" del in_avals # Unused.\n",
|
||||
" del in_avals # Unused\n",
|
||||
" pred, *in_vals = in_vals\n",
|
||||
" flat_vals, in_tree = tree_flatten(in_vals)\n",
|
||||
" operand = xops.Tuple(c, flat_vals)\n",
|
||||
@ -3857,6 +3889,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_end_of_cell_marker": 0,
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
@ -3954,7 +3987,8 @@
|
||||
" eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],\n",
|
||||
" dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
|
||||
" outs2)\n",
|
||||
" return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]\n",
|
||||
" res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals\n",
|
||||
" return eqn1, eqn2, unks_out, res\n",
|
||||
"partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn"
|
||||
]
|
||||
},
|
||||
@ -4002,12 +4036,37 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
|
||||
"print(out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"hide-input"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
|
||||
" true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
|
||||
" new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}\n",
|
||||
" lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
|
||||
" rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>\n",
|
||||
" pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
|
||||
" for x in eqn.inputs)))\n",
|
||||
" return vcat([lhs >> pp(' = ') >> rhs,\n",
|
||||
" pp_jaxpr(true_jaxpr).indent(2),\n",
|
||||
" pp_jaxpr(false_jaxpr).indent(2)])\n",
|
||||
"pp_rules[cond_p] = pprint_cond"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -1304,6 +1304,7 @@ def make_jaxpr_v1(f, *avals_in):
|
||||
```{code-cell}
|
||||
:tags: [hide-input]
|
||||
|
||||
from typing import DefaultDict
|
||||
from collections import defaultdict
|
||||
import string
|
||||
|
||||
@ -1338,7 +1339,7 @@ def pp(s: Any) -> PPrint:
|
||||
def vcat(ps: List[PPrint]) -> PPrint:
|
||||
return sum(ps, pp(''))
|
||||
|
||||
def pp_jaxpr(jaxpr: Jaxpr):
|
||||
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
|
||||
namegen = (''.join(s) for r in it.count(1)
|
||||
for s in it.permutations(string.ascii_lowercase, r))
|
||||
names = defaultdict(lambda: next(namegen))
|
||||
@ -1349,15 +1350,19 @@ def pp_jaxpr(jaxpr: Jaxpr):
|
||||
return (pp(f'{{ lambda {in_binders} .') +
|
||||
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
|
||||
|
||||
def var_str(names: Dict[Var, str], v: Var) -> str:
|
||||
def var_str(names: DefaultDict[Var, str], v: Var) -> str:
|
||||
return f'{names[v]}:{v.aval.str_short()}'
|
||||
|
||||
def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return lhs >> pp(' = ') >> rhs
|
||||
def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
rule = pp_rules.get(eqn.primitive)
|
||||
if rule:
|
||||
return rule(names, eqn)
|
||||
else:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return lhs >> pp(' = ') >> rhs
|
||||
|
||||
def pp_params(params: Dict[str, Any]) -> PPrint:
|
||||
items = sorted(params.items())
|
||||
@ -1367,6 +1372,7 @@ def pp_params(params: Dict[str, Any]) -> PPrint:
|
||||
return pp(' ')
|
||||
|
||||
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
|
||||
pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
@ -1611,7 +1617,7 @@ input_handlers = {ty: default_input_handler for ty in
|
||||
[bool, int, float, np.ndarray, np.float64, np.float32]}
|
||||
|
||||
def handle_result(aval: ShapedArray, buf):
|
||||
del aval # Unused for now.
|
||||
del aval # Unused for now
|
||||
return buf.to_py()
|
||||
|
||||
xla_translations = {}
|
||||
@ -1709,7 +1715,7 @@ level." Let's fix that!
|
||||
|
||||
```{code-cell}
|
||||
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
|
||||
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
|
||||
num_consts=len(new_consts))
|
||||
@ -1732,7 +1738,7 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
|
||||
|
||||
```{code-cell}
|
||||
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
|
||||
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
|
||||
num_consts=len(new_consts))
|
||||
@ -1760,7 +1766,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
|
||||
|
||||
```{code-cell}
|
||||
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
jaxpr_type = typecheck_jaxpr(jaxpr)
|
||||
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
|
||||
raise TypeError
|
||||
@ -1857,6 +1863,20 @@ print(y)
|
||||
print(ydot)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
:tags: [hide-input]
|
||||
|
||||
def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return vcat([lhs >> pp(' = ') >> rhs,
|
||||
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
|
||||
pp_rules[xla_call_p] = pprint_xla_call
|
||||
```
|
||||
|
||||
## Part 4: `linearize` and `vjp` (and `grad`!)
|
||||
|
||||
The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
|
||||
@ -2021,11 +2041,11 @@ forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
|
||||
operations out of Python first before sorting out what can be evaluated now
|
||||
and what must be delayed, we want only to form a jaxpr for those operations
|
||||
that _must_ be delayed due to a dependence on unknown inputs. In the context
|
||||
of automatic differentiation, this is the feature that ultimately enables us to
|
||||
handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
|
||||
flow works because partial evaluation keeps the primal computation in Python.
|
||||
As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
|
||||
what can be evaluated and what must be staged out into a jaxpr.
|
||||
of automatic differentiation, this is the feature that ultimately enables us
|
||||
to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python
|
||||
control flow works because partial evaluation keeps the primal computation in
|
||||
Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly
|
||||
sort out what can be evaluated and what must be staged out into a jaxpr.
|
||||
|
||||
First, we start with a `PartialVal` class, which represents a value that can
|
||||
be either known or unknown:
|
||||
@ -2071,8 +2091,9 @@ interpreter will build a jaxpr on the fly while tracking data dependencies. To
|
||||
do so, it builds a bipartite directed acyclic graph (DAG) between
|
||||
`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
|
||||
nodes, representing formulas for how to compute some values from others. One
|
||||
kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
|
||||
application, but we also have recipe types for constants and lambda binders:
|
||||
kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
|
||||
primitive application, but we also have recipe types for constants and lambda
|
||||
binders:
|
||||
|
||||
```{code-cell}
|
||||
from weakref import ref, ReferenceType
|
||||
@ -2172,11 +2193,12 @@ The jaxpr corresponds to a topological sort of the graph.
|
||||
```{code-cell}
|
||||
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
|
||||
tracers_out: List[PartialEvalTracer]):
|
||||
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
|
||||
constvar_to_val = {}
|
||||
constid_to_var = {}
|
||||
processed_eqns = set()
|
||||
eqns = []
|
||||
tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
|
||||
for t in tracers_in}
|
||||
constvar_to_val: Dict[int, Any] = {}
|
||||
constid_to_var: Dict[int, Var] = {}
|
||||
processed_eqns: Set[int] = set()
|
||||
eqns: List[JaxprEqn] = []
|
||||
for t in toposort(tracers_out, tracer_parents):
|
||||
if isinstance(t.recipe, LambdaBindingRecipe):
|
||||
assert id(t) in set(map(id, tracers_in))
|
||||
@ -2276,7 +2298,7 @@ jaxprs, which we'll call `xla_call_peval_eqn`.
|
||||
|
||||
```{code-cell}
|
||||
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
in_unknowns = [not t.pval.is_known for t in tracers]
|
||||
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
|
||||
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
|
||||
@ -2299,8 +2321,8 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
|
||||
env: Dict[Var, bool] = {}
|
||||
residuals: Set[Var] = set()
|
||||
|
||||
def read(v: Atom) -> bool:
|
||||
return type(v) is Var and env[v]
|
||||
def read(x: Atom) -> bool:
|
||||
return type(x) is Var and env[x]
|
||||
|
||||
def write(unk: bool, v: Var) -> None:
|
||||
env[v] = unk
|
||||
@ -2332,6 +2354,7 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
|
||||
out_unknowns = map(op.or_, out_unknowns, instantiate)
|
||||
|
||||
residuals, num_res = list(residuals), len(residuals)
|
||||
assert all(type(v) is Var for v in residuals), residuals
|
||||
|
||||
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
|
||||
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
|
||||
@ -2363,16 +2386,16 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
|
||||
partial_eval_jaxpr_rules = {}
|
||||
|
||||
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
|
||||
ins1, ins2 = partition_list(unks_in, eqn.inputs)
|
||||
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
|
||||
residuals, _ = split_list(jaxpr2.in_binders, num_res)
|
||||
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
|
||||
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
|
||||
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
|
||||
outs1 + residuals)
|
||||
out_binders1 + residuals)
|
||||
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
|
||||
dict(jaxpr=jaxpr2, num_consts=0), outs2)
|
||||
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
|
||||
return eqn1, eqn2, unks_out, residuals
|
||||
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
|
||||
```
|
||||
@ -2535,7 +2558,7 @@ def add_transpose_rule(cts, x, y):
|
||||
transpose_rules[add_p] = add_transpose_rule
|
||||
|
||||
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
undef_primals = [type(x) is UndefPrimal for x in invals]
|
||||
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
|
||||
residuals, _ = partition_list(undef_primals, invals)
|
||||
@ -2813,7 +2836,7 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
|
||||
abstract_eval_rules[cond_p] = cond_abstract_eval
|
||||
|
||||
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
|
||||
del in_avals # Unused.
|
||||
del in_avals # Unused
|
||||
pred, *in_vals = in_vals
|
||||
flat_vals, in_tree = tree_flatten(in_vals)
|
||||
operand = xops.Tuple(c, flat_vals)
|
||||
@ -2930,7 +2953,8 @@ def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
|
||||
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
|
||||
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
|
||||
outs2)
|
||||
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
|
||||
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
|
||||
return eqn1, eqn2, unks_out, res
|
||||
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
|
||||
```
|
||||
|
||||
@ -2961,3 +2985,19 @@ transpose_rules[cond_p] = cond_transpose_rule
|
||||
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
|
||||
print(out)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
:tags: [hide-input]
|
||||
|
||||
def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
|
||||
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return vcat([lhs >> pp(' = ') >> rhs,
|
||||
pp_jaxpr(true_jaxpr).indent(2),
|
||||
pp_jaxpr(false_jaxpr).indent(2)])
|
||||
pp_rules[cond_p] = pprint_cond
|
||||
```
|
||||
|
@ -1254,6 +1254,7 @@ def make_jaxpr_v1(f, *avals_in):
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
# + tags=["hide-input"]
|
||||
from typing import DefaultDict
|
||||
from collections import defaultdict
|
||||
import string
|
||||
|
||||
@ -1288,7 +1289,7 @@ def pp(s: Any) -> PPrint:
|
||||
def vcat(ps: List[PPrint]) -> PPrint:
|
||||
return sum(ps, pp(''))
|
||||
|
||||
def pp_jaxpr(jaxpr: Jaxpr):
|
||||
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
|
||||
namegen = (''.join(s) for r in it.count(1)
|
||||
for s in it.permutations(string.ascii_lowercase, r))
|
||||
names = defaultdict(lambda: next(namegen))
|
||||
@ -1299,15 +1300,19 @@ def pp_jaxpr(jaxpr: Jaxpr):
|
||||
return (pp(f'{{ lambda {in_binders} .') +
|
||||
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
|
||||
|
||||
def var_str(names: Dict[Var, str], v: Var) -> str:
|
||||
def var_str(names: DefaultDict[Var, str], v: Var) -> str:
|
||||
return f'{names[v]}:{v.aval.str_short()}'
|
||||
|
||||
def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return lhs >> pp(' = ') >> rhs
|
||||
def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
rule = pp_rules.get(eqn.primitive)
|
||||
if rule:
|
||||
return rule(names, eqn)
|
||||
else:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return lhs >> pp(' = ') >> rhs
|
||||
|
||||
def pp_params(params: Dict[str, Any]) -> PPrint:
|
||||
items = sorted(params.items())
|
||||
@ -1317,6 +1322,7 @@ def pp_params(params: Dict[str, Any]) -> PPrint:
|
||||
return pp(' ')
|
||||
|
||||
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
|
||||
pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}
|
||||
# -
|
||||
|
||||
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
|
||||
@ -1548,7 +1554,7 @@ input_handlers = {ty: default_input_handler for ty in
|
||||
[bool, int, float, np.ndarray, np.float64, np.float32]}
|
||||
|
||||
def handle_result(aval: ShapedArray, buf):
|
||||
del aval # Unused for now.
|
||||
del aval # Unused for now
|
||||
return buf.to_py()
|
||||
|
||||
xla_translations = {}
|
||||
@ -1637,7 +1643,7 @@ print(jit(deriv(deriv(f)))(3.))
|
||||
|
||||
# +
|
||||
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
|
||||
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
|
||||
num_consts=len(new_consts))
|
||||
@ -1659,7 +1665,7 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
|
||||
|
||||
# +
|
||||
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
|
||||
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
|
||||
num_consts=len(new_consts))
|
||||
@ -1686,7 +1692,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
|
||||
|
||||
# +
|
||||
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
jaxpr_type = typecheck_jaxpr(jaxpr)
|
||||
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
|
||||
raise TypeError
|
||||
@ -1775,6 +1781,17 @@ x, xdot = 3., 1.
|
||||
y, ydot = jvp(f, (x,), (xdot,))
|
||||
print(y)
|
||||
print(ydot)
|
||||
|
||||
# + tags=["hide-input"]
|
||||
def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return vcat([lhs >> pp(' = ') >> rhs,
|
||||
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
|
||||
pp_rules[xla_call_p] = pprint_xla_call
|
||||
# -
|
||||
|
||||
# ## Part 4: `linearize` and `vjp` (and `grad`!)
|
||||
@ -1939,11 +1956,11 @@ def vspace(aval: ShapedArray) -> ShapedArray:
|
||||
# operations out of Python first before sorting out what can be evaluated now
|
||||
# and what must be delayed, we want only to form a jaxpr for those operations
|
||||
# that _must_ be delayed due to a dependence on unknown inputs. In the context
|
||||
# of automatic differentiation, this is the feature that ultimately enables us to
|
||||
# handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
|
||||
# flow works because partial evaluation keeps the primal computation in Python.
|
||||
# As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
|
||||
# what can be evaluated and what must be staged out into a jaxpr.
|
||||
# of automatic differentiation, this is the feature that ultimately enables us
|
||||
# to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python
|
||||
# control flow works because partial evaluation keeps the primal computation in
|
||||
# Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly
|
||||
# sort out what can be evaluated and what must be staged out into a jaxpr.
|
||||
#
|
||||
# First, we start with a `PartialVal` class, which represents a value that can
|
||||
# be either known or unknown:
|
||||
@ -1985,8 +2002,9 @@ def partial_eval_flat(f: Callable, pvals_in: List[PartialVal]
|
||||
# do so, it builds a bipartite directed acyclic graph (DAG) between
|
||||
# `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
|
||||
# nodes, representing formulas for how to compute some values from others. One
|
||||
# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
|
||||
# application, but we also have recipe types for constants and lambda binders:
|
||||
# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
|
||||
# primitive application, but we also have recipe types for constants and lambda
|
||||
# binders:
|
||||
|
||||
# +
|
||||
from weakref import ref, ReferenceType
|
||||
@ -2086,11 +2104,12 @@ partial_eval_rules = {}
|
||||
# +
|
||||
def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
|
||||
tracers_out: List[PartialEvalTracer]):
|
||||
tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
|
||||
constvar_to_val = {}
|
||||
constid_to_var = {}
|
||||
processed_eqns = set()
|
||||
eqns = []
|
||||
tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
|
||||
for t in tracers_in}
|
||||
constvar_to_val: Dict[int, Any] = {}
|
||||
constid_to_var: Dict[int, Var] = {}
|
||||
processed_eqns: Set[int] = set()
|
||||
eqns: List[JaxprEqn] = []
|
||||
for t in toposort(tracers_out, tracer_parents):
|
||||
if isinstance(t.recipe, LambdaBindingRecipe):
|
||||
assert id(t) in set(map(id, tracers_in))
|
||||
@ -2185,7 +2204,7 @@ print(sin_lin(1.), cos(3.))
|
||||
|
||||
# +
|
||||
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
in_unknowns = [not t.pval.is_known for t in tracers]
|
||||
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
|
||||
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
|
||||
@ -2208,8 +2227,8 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
|
||||
env: Dict[Var, bool] = {}
|
||||
residuals: Set[Var] = set()
|
||||
|
||||
def read(v: Atom) -> bool:
|
||||
return type(v) is Var and env[v]
|
||||
def read(x: Atom) -> bool:
|
||||
return type(x) is Var and env[x]
|
||||
|
||||
def write(unk: bool, v: Var) -> None:
|
||||
env[v] = unk
|
||||
@ -2241,6 +2260,7 @@ def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],
|
||||
out_unknowns = map(op.or_, out_unknowns, instantiate)
|
||||
|
||||
residuals, num_res = list(residuals), len(residuals)
|
||||
assert all(type(v) is Var for v in residuals), residuals
|
||||
|
||||
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
|
||||
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
|
||||
@ -2272,16 +2292,16 @@ def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
|
||||
partial_eval_jaxpr_rules = {}
|
||||
|
||||
def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
|
||||
ins1, ins2 = partition_list(unks_in, eqn.inputs)
|
||||
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
|
||||
residuals, _ = split_list(jaxpr2.in_binders, num_res)
|
||||
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
|
||||
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
|
||||
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
|
||||
outs1 + residuals)
|
||||
out_binders1 + residuals)
|
||||
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
|
||||
dict(jaxpr=jaxpr2, num_consts=0), outs2)
|
||||
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
|
||||
return eqn1, eqn2, unks_out, residuals
|
||||
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
|
||||
# -
|
||||
@ -2442,7 +2462,7 @@ def add_transpose_rule(cts, x, y):
|
||||
transpose_rules[add_p] = add_transpose_rule
|
||||
|
||||
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
|
||||
del num_consts # Unused.
|
||||
del num_consts # Unused
|
||||
undef_primals = [type(x) is UndefPrimal for x in invals]
|
||||
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
|
||||
residuals, _ = partition_list(undef_primals, invals)
|
||||
@ -2701,7 +2721,7 @@ def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
|
||||
abstract_eval_rules[cond_p] = cond_abstract_eval
|
||||
|
||||
def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
|
||||
del in_avals # Unused.
|
||||
del in_avals # Unused
|
||||
pred, *in_vals = in_vals
|
||||
flat_vals, in_tree = tree_flatten(in_vals)
|
||||
operand = xops.Tuple(c, flat_vals)
|
||||
@ -2791,8 +2811,6 @@ def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
|
||||
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
|
||||
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
|
||||
return new_jaxpr1, new_jaxpr2
|
||||
|
||||
|
||||
# -
|
||||
|
||||
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
|
||||
@ -2815,7 +2833,8 @@ def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
|
||||
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
|
||||
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
|
||||
outs2)
|
||||
return eqn1, eqn2, unks_out, [eqn.inputs[0], *residuals]
|
||||
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
|
||||
return eqn1, eqn2, unks_out, res
|
||||
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
|
||||
|
||||
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
|
||||
@ -2839,3 +2858,16 @@ transpose_rules[cond_p] = cond_transpose_rule
|
||||
|
||||
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
|
||||
print(out)
|
||||
|
||||
# + tags=["hide-input"]
|
||||
def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
|
||||
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
|
||||
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
|
||||
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
|
||||
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
|
||||
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
|
||||
for x in eqn.inputs)))
|
||||
return vcat([lhs >> pp(' = ') >> rhs,
|
||||
pp_jaxpr(true_jaxpr).indent(2),
|
||||
pp_jaxpr(false_jaxpr).indent(2)])
|
||||
pp_rules[cond_p] = pprint_cond
|
||||
|
Loading…
x
Reference in New Issue
Block a user