autodidax: delete while_loop for now

This commit is contained in:
Matthew Johnson 2021-05-05 12:44:49 -07:00
parent 83cd42271b
commit d88acd8b8c
3 changed files with 6 additions and 412 deletions

View File

@ -3077,7 +3077,7 @@
"\n",
"There are actually two rules to write: one for trace-time partial evaluation,\n",
"which we'll call `xla_call_partial_eval`, and one for partial evaluation of\n",
"jaxprs, whicch we'll call `xla_call_peval_eqn`."
"jaxprs, which we'll call `xla_call_peval_eqn`."
]
},
{
@ -3558,7 +3558,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 5: the control flow primitives `cond` and `while_loop`\n",
"## Part 5: the control flow primitives `cond`\n",
"\n",
"Next we'll add higher-order primitives for staged-out control flow. These\n",
"resemble `jit` from Part 3, another higher-order primitive, but differ in that\n",
@ -4014,187 +4014,6 @@
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Adding `while_loop`\n",
"\n",
"Next we'll add a primitive for looping behavior in a jaxpr. We'll use\n",
"`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first\n",
"function-valued argument represents the loop condition, the second represents\n",
"the loop body, and the final argument is the initial value of the carry.\n",
"\n",
"After `cond`, adding `while_loop` is not so different:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def while_loop(cond_fn, body_fn, init_val):\n",
" init_val, in_tree = tree_flatten(init_val)\n",
" avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]\n",
" cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)\n",
" body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)\n",
" cond_jaxpr, body_jaxpr = _join_jaxpr_consts(\n",
" cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))\n",
" if cond_tree != tree_flatten(True)[1]: raise TypeError\n",
" if in_tree != in_tree_: raise TypeError\n",
" outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,\n",
" cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)\n",
" return tree_unflatten(in_tree, outs)\n",
"while_loop_p = Primitive('while_loop')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def while_loop_impl(*args, cond_jaxpr, body_jaxpr):\n",
" consts, carry = split_list(args, _loop_num_consts(body_jaxpr))\n",
" while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:\n",
" carry = eval_jaxpr(body_jaxpr, [*consts, *carry])\n",
" return carry\n",
"impl_rules[while_loop_p] = while_loop_impl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def _loop_num_consts(body_jaxpr: Jaxpr) -> int:\n",
" return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_end_of_cell_marker": 0,
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice the convention that `args = [*consts, *carry]`.\n",
"\n",
"The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the\n",
"convention that all the binders for tangent values are appended after all the\n",
"binders for primal values, like `args = [*primals, *tangents]`. But that's in\n",
"tension with our `while_loop` convention that the carry binders come after all\n",
"the constant binders, i.e. that `args = [*consts, *carry]`, because both the\n",
"constants and the carries can have their own tangents. For this reason, we\n",
"introduce the `_loop_jvp_binders` helper to rearrange binders as needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):\n",
" num_consts = _loop_num_consts(body_jaxpr)\n",
" body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)\n",
" cond_jaxpr, body_jaxpr = _loop_jvp_binders(\n",
" cond_jaxpr, body_jaxpr, len(body_consts), num_consts)\n",
" outs = bind(while_loop_p, *body_consts, *primals, *tangents,\n",
" cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)\n",
" primals_out, tangents_out = split_half(outs)\n",
" return primals_out, tangents_out\n",
"jvp_rules[while_loop_p] = while_loop_jvp_rule\n",
"\n",
"def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int\n",
" ) -> Jaxpr:\n",
" # body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]\n",
" jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)\n",
" primal_binders, tangent_binders = split_half(binders)\n",
" consts , carry = split_list(primal_binders , n2)\n",
" consts_dot, carry_dot = split_list(tangent_binders, n2)\n",
" new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot\n",
" new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)\n",
" typecheck_jaxpr(new_body_jaxpr)\n",
"\n",
" # cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]\n",
" assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)\n",
" consts, carry = split_list(cond_jaxpr.in_binders, n2)\n",
" new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot\n",
" new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)\n",
"\n",
" return new_cond_jaxpr, new_body_jaxpr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),\n",
" (1.,), (1.,))\n",
"print(out_tan)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"def f(x):\n",
" def cond_fn(i, _):\n",
" return i < 3\n",
" def body_fn(i, x):\n",
" return i + 1, cos(x)\n",
" _, y = while_loop(cond_fn, body_fn, (0, x))\n",
" return y\n",
"\n",
"def g(x):\n",
" return cos(cos(cos(x)))\n",
"\n",
"print(jvp(f, (1.,), (1.,)))\n",
"print(jvp(g, (1.,), (1.,)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vmap rule for `while_loop` presents two cases:\n",
"1. if the output of `cond_fun` is not batched, then the loop has the same\n",
" basic structure, just with a batched body;\n",
"2. but if the output of `cond_fun` is batched, we must represent a batch of\n",
" loops which might run for different numbers of iterations.\n",
"\n",
"...Stay tuned for the thrilling conclusion!"
]
}
],
"metadata": {

View File

@ -2279,7 +2279,7 @@ perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.
There are actually two rules to write: one for trace-time partial evaluation,
which we'll call `xla_call_partial_eval`, and one for partial evaluation of
jaxprs, whicch we'll call `xla_call_peval_eqn`.
jaxprs, which we'll call `xla_call_peval_eqn`.
```{code-cell}
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
@ -2646,7 +2646,7 @@ _, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
```
## Part 5: the control flow primitives `cond` and `while_loop`
## Part 5: the control flow primitives `cond`
Next we'll add higher-order primitives for staged-out control flow. These
resemble `jit` from Part 3, another higher-order primitive, but differ in that
@ -2967,119 +2967,3 @@ transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
```
### Adding `while_loop`
Next we'll add a primitive for looping behavior in a jaxpr. We'll use
`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
function-valued argument represents the loop condition, the second represents
the loop body, and the final argument is the initial value of the carry.
After `cond`, adding `while_loop` is not so different:
```{code-cell}
def while_loop(cond_fn, body_fn, init_val):
init_val, in_tree = tree_flatten(init_val)
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
if cond_tree != tree_flatten(True)[1]: raise TypeError
if in_tree != in_tree_: raise TypeError
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')
```
```{code-cell}
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
return carry
impl_rules[while_loop_p] = while_loop_impl
```
```{code-cell}
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
```
```{code-cell}
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
```
Notice the convention that `args = [*consts, *carry]`.
The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
convention that all the binders for tangent values are appended after all the
binders for primal values, like `args = [*primals, *tangents]`. But that's in
tension with our `while_loop` convention that the carry binders come after all
the constant binders, i.e. that `args = [*consts, *carry]`, because both the
constants and the carries can have their own tangents. For this reason, we
introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
```{code-cell}
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
num_consts = _loop_num_consts(body_jaxpr)
body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
cond_jaxpr, body_jaxpr = _loop_jvp_binders(
cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
outs = bind(while_loop_p, *body_consts, *primals, *tangents,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
) -> Jaxpr:
# body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
primal_binders, tangent_binders = split_half(binders)
consts , carry = split_list(primal_binders , n2)
consts_dot, carry_dot = split_list(tangent_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
typecheck_jaxpr(new_body_jaxpr)
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
consts, carry = split_list(cond_jaxpr.in_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
return new_cond_jaxpr, new_body_jaxpr
```
```{code-cell}
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
(1.,), (1.,))
print(out_tan)
```
```{code-cell}
def f(x):
def cond_fn(i, _):
return i < 3
def body_fn(i, x):
return i + 1, cos(x)
_, y = while_loop(cond_fn, body_fn, (0, x))
return y
def g(x):
return cos(cos(cos(x)))
print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
```
The vmap rule for `while_loop` presents two cases:
1. if the output of `cond_fun` is not batched, then the loop has the same
basic structure, just with a batched body;
2. but if the output of `cond_fun` is batched, we must represent a batch of
loops which might run for different numbers of iterations.
...Stay tuned for the thrilling conclusion!

View File

@ -2188,7 +2188,7 @@ print(sin_lin(1.), cos(3.))
#
# There are actually two rules to write: one for trace-time partial evaluation,
# which we'll call `xla_call_partial_eval`, and one for partial evaluation of
# jaxprs, whicch we'll call `xla_call_peval_eqn`.
# jaxprs, which we'll call `xla_call_peval_eqn`.
# +
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
@ -2548,7 +2548,7 @@ _, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
# -
# ## Part 5: the control flow primitives `cond` and `while_loop`
# ## Part 5: the control flow primitives `cond`
#
# Next we'll add higher-order primitives for staged-out control flow. These
# resemble `jit` from Part 3, another higher-order primitive, but differ in that
@ -2845,112 +2845,3 @@ transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
# ### Adding `while_loop`
#
# Next we'll add a primitive for looping behavior in a jaxpr. We'll use
# `while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
# function-valued argument represents the loop condition, the second represents
# the loop body, and the final argument is the initial value of the carry.
#
# After `cond`, adding `while_loop` is not so different:
def while_loop(cond_fn, body_fn, init_val):
init_val, in_tree = tree_flatten(init_val)
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
if cond_tree != tree_flatten(True)[1]: raise TypeError
if in_tree != in_tree_: raise TypeError
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
return tree_unflatten(in_tree, outs)
while_loop_p = Primitive('while_loop')
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
return carry
impl_rules[while_loop_p] = while_loop_impl
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
print(out)
# Notice the convention that `args = [*consts, *carry]`.
#
# The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
# convention that all the binders for tangent values are appended after all the
# binders for primal values, like `args = [*primals, *tangents]`. But that's in
# tension with our `while_loop` convention that the carry binders come after all
# the constant binders, i.e. that `args = [*consts, *carry]`, because both the
# constants and the carries can have their own tangents. For this reason, we
# introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
# +
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
num_consts = _loop_num_consts(body_jaxpr)
body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
cond_jaxpr, body_jaxpr = _loop_jvp_binders(
cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
outs = bind(while_loop_p, *body_consts, *primals, *tangents,
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[while_loop_p] = while_loop_jvp_rule
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
) -> Jaxpr:
# body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
primal_binders, tangent_binders = split_half(binders)
consts , carry = split_list(primal_binders , n2)
consts_dot, carry_dot = split_list(tangent_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
typecheck_jaxpr(new_body_jaxpr)
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
consts, carry = split_list(cond_jaxpr.in_binders, n2)
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
return new_cond_jaxpr, new_body_jaxpr
# -
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
(1.,), (1.,))
print(out_tan)
# +
def f(x):
def cond_fn(i, _):
return i < 3
def body_fn(i, x):
return i + 1, cos(x)
_, y = while_loop(cond_fn, body_fn, (0, x))
return y
def g(x):
return cos(cos(cos(x)))
print(jvp(f, (1.,), (1.,)))
print(jvp(g, (1.,), (1.,)))
# -
# The vmap rule for `while_loop` presents two cases:
# 1. if the output of `cond_fun` is not batched, then the loop has the same
# basic structure, just with a batched body;
# 2. but if the output of `cond_fun` is batched, we must represent a batch of
# loops which might run for different numbers of iterations.
#
# ...Stay tuned for the thrilling conclusion!