mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
autodidax: delete while_loop for now
This commit is contained in:
parent
83cd42271b
commit
d88acd8b8c
@ -3077,7 +3077,7 @@
|
||||
"\n",
|
||||
"There are actually two rules to write: one for trace-time partial evaluation,\n",
|
||||
"which we'll call `xla_call_partial_eval`, and one for partial evaluation of\n",
|
||||
"jaxprs, whicch we'll call `xla_call_peval_eqn`."
|
||||
"jaxprs, which we'll call `xla_call_peval_eqn`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -3558,7 +3558,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part 5: the control flow primitives `cond` and `while_loop`\n",
|
||||
"## Part 5: the control flow primitives `cond`\n",
|
||||
"\n",
|
||||
"Next we'll add higher-order primitives for staged-out control flow. These\n",
|
||||
"resemble `jit` from Part 3, another higher-order primitive, but differ in that\n",
|
||||
@ -4014,187 +4014,6 @@
|
||||
"out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
|
||||
"print(out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Adding `while_loop`\n",
|
||||
"\n",
|
||||
"Next we'll add a primitive for looping behavior in a jaxpr. We'll use\n",
|
||||
"`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first\n",
|
||||
"function-valued argument represents the loop condition, the second represents\n",
|
||||
"the loop body, and the final argument is the initial value of the carry.\n",
|
||||
"\n",
|
||||
"After `cond`, adding `while_loop` is not so different:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def while_loop(cond_fn, body_fn, init_val):\n",
|
||||
" init_val, in_tree = tree_flatten(init_val)\n",
|
||||
" avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]\n",
|
||||
" cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)\n",
|
||||
" body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)\n",
|
||||
" cond_jaxpr, body_jaxpr = _join_jaxpr_consts(\n",
|
||||
" cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))\n",
|
||||
" if cond_tree != tree_flatten(True)[1]: raise TypeError\n",
|
||||
" if in_tree != in_tree_: raise TypeError\n",
|
||||
" outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,\n",
|
||||
" cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)\n",
|
||||
" return tree_unflatten(in_tree, outs)\n",
|
||||
"while_loop_p = Primitive('while_loop')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def while_loop_impl(*args, cond_jaxpr, body_jaxpr):\n",
|
||||
" consts, carry = split_list(args, _loop_num_consts(body_jaxpr))\n",
|
||||
" while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:\n",
|
||||
" carry = eval_jaxpr(body_jaxpr, [*consts, *carry])\n",
|
||||
" return carry\n",
|
||||
"impl_rules[while_loop_p] = while_loop_impl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _loop_num_consts(body_jaxpr: Jaxpr) -> int:\n",
|
||||
" return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_end_of_cell_marker": 0,
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)\n",
|
||||
"print(out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notice the convention that `args = [*consts, *carry]`.\n",
|
||||
"\n",
|
||||
"The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the\n",
|
||||
"convention that all the binders for tangent values are appended after all the\n",
|
||||
"binders for primal values, like `args = [*primals, *tangents]`. But that's in\n",
|
||||
"tension with our `while_loop` convention that the carry binders come after all\n",
|
||||
"the constant binders, i.e. that `args = [*consts, *carry]`, because both the\n",
|
||||
"constants and the carries can have their own tangents. For this reason, we\n",
|
||||
"introduce the `_loop_jvp_binders` helper to rearrange binders as needed."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):\n",
|
||||
" num_consts = _loop_num_consts(body_jaxpr)\n",
|
||||
" body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)\n",
|
||||
" cond_jaxpr, body_jaxpr = _loop_jvp_binders(\n",
|
||||
" cond_jaxpr, body_jaxpr, len(body_consts), num_consts)\n",
|
||||
" outs = bind(while_loop_p, *body_consts, *primals, *tangents,\n",
|
||||
" cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)\n",
|
||||
" primals_out, tangents_out = split_half(outs)\n",
|
||||
" return primals_out, tangents_out\n",
|
||||
"jvp_rules[while_loop_p] = while_loop_jvp_rule\n",
|
||||
"\n",
|
||||
"def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int\n",
|
||||
" ) -> Jaxpr:\n",
|
||||
" # body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]\n",
|
||||
" jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)\n",
|
||||
" primal_binders, tangent_binders = split_half(binders)\n",
|
||||
" consts , carry = split_list(primal_binders , n2)\n",
|
||||
" consts_dot, carry_dot = split_list(tangent_binders, n2)\n",
|
||||
" new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot\n",
|
||||
" new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)\n",
|
||||
" typecheck_jaxpr(new_body_jaxpr)\n",
|
||||
"\n",
|
||||
" # cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]\n",
|
||||
" assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)\n",
|
||||
" consts, carry = split_list(cond_jaxpr.in_binders, n2)\n",
|
||||
" new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot\n",
|
||||
" new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)\n",
|
||||
"\n",
|
||||
" return new_cond_jaxpr, new_body_jaxpr"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),\n",
|
||||
" (1.,), (1.,))\n",
|
||||
"print(out_tan)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def f(x):\n",
|
||||
" def cond_fn(i, _):\n",
|
||||
" return i < 3\n",
|
||||
" def body_fn(i, x):\n",
|
||||
" return i + 1, cos(x)\n",
|
||||
" _, y = while_loop(cond_fn, body_fn, (0, x))\n",
|
||||
" return y\n",
|
||||
"\n",
|
||||
"def g(x):\n",
|
||||
" return cos(cos(cos(x)))\n",
|
||||
"\n",
|
||||
"print(jvp(f, (1.,), (1.,)))\n",
|
||||
"print(jvp(g, (1.,), (1.,)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The vmap rule for `while_loop` presents two cases:\n",
|
||||
"1. if the output of `cond_fun` is not batched, then the loop has the same\n",
|
||||
" basic structure, just with a batched body;\n",
|
||||
"2. but if the output of `cond_fun` is batched, we must represent a batch of\n",
|
||||
" loops which might run for different numbers of iterations.\n",
|
||||
"\n",
|
||||
"...Stay tuned for the thrilling conclusion!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -2279,7 +2279,7 @@ perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.
|
||||
|
||||
There are actually two rules to write: one for trace-time partial evaluation,
|
||||
which we'll call `xla_call_partial_eval`, and one for partial evaluation of
|
||||
jaxprs, whicch we'll call `xla_call_peval_eqn`.
|
||||
jaxprs, which we'll call `xla_call_peval_eqn`.
|
||||
|
||||
```{code-cell}
|
||||
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
@ -2646,7 +2646,7 @@ _, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
|
||||
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
|
||||
```
|
||||
|
||||
## Part 5: the control flow primitives `cond` and `while_loop`
|
||||
## Part 5: the control flow primitives `cond`
|
||||
|
||||
Next we'll add higher-order primitives for staged-out control flow. These
|
||||
resemble `jit` from Part 3, another higher-order primitive, but differ in that
|
||||
@ -2967,119 +2967,3 @@ transpose_rules[cond_p] = cond_transpose_rule
|
||||
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
|
||||
print(out)
|
||||
```
|
||||
|
||||
### Adding `while_loop`
|
||||
|
||||
Next we'll add a primitive for looping behavior in a jaxpr. We'll use
|
||||
`while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
|
||||
function-valued argument represents the loop condition, the second represents
|
||||
the loop body, and the final argument is the initial value of the carry.
|
||||
|
||||
After `cond`, adding `while_loop` is not so different:
|
||||
|
||||
```{code-cell}
|
||||
def while_loop(cond_fn, body_fn, init_val):
|
||||
init_val, in_tree = tree_flatten(init_val)
|
||||
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
|
||||
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
|
||||
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
|
||||
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
|
||||
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
|
||||
if cond_tree != tree_flatten(True)[1]: raise TypeError
|
||||
if in_tree != in_tree_: raise TypeError
|
||||
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
|
||||
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
return tree_unflatten(in_tree, outs)
|
||||
while_loop_p = Primitive('while_loop')
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
|
||||
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
|
||||
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
|
||||
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
|
||||
return carry
|
||||
impl_rules[while_loop_p] = while_loop_impl
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
|
||||
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
|
||||
print(out)
|
||||
```
|
||||
|
||||
Notice the convention that `args = [*consts, *carry]`.
|
||||
|
||||
The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
|
||||
convention that all the binders for tangent values are appended after all the
|
||||
binders for primal values, like `args = [*primals, *tangents]`. But that's in
|
||||
tension with our `while_loop` convention that the carry binders come after all
|
||||
the constant binders, i.e. that `args = [*consts, *carry]`, because both the
|
||||
constants and the carries can have their own tangents. For this reason, we
|
||||
introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
|
||||
|
||||
```{code-cell}
|
||||
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
|
||||
num_consts = _loop_num_consts(body_jaxpr)
|
||||
body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
|
||||
cond_jaxpr, body_jaxpr = _loop_jvp_binders(
|
||||
cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
|
||||
outs = bind(while_loop_p, *body_consts, *primals, *tangents,
|
||||
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
primals_out, tangents_out = split_half(outs)
|
||||
return primals_out, tangents_out
|
||||
jvp_rules[while_loop_p] = while_loop_jvp_rule
|
||||
|
||||
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
|
||||
) -> Jaxpr:
|
||||
# body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]
|
||||
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
|
||||
primal_binders, tangent_binders = split_half(binders)
|
||||
consts , carry = split_list(primal_binders , n2)
|
||||
consts_dot, carry_dot = split_list(tangent_binders, n2)
|
||||
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
|
||||
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
|
||||
typecheck_jaxpr(new_body_jaxpr)
|
||||
|
||||
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
|
||||
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
|
||||
consts, carry = split_list(cond_jaxpr.in_binders, n2)
|
||||
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
|
||||
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
|
||||
|
||||
return new_cond_jaxpr, new_body_jaxpr
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
|
||||
(1.,), (1.,))
|
||||
print(out_tan)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def f(x):
|
||||
def cond_fn(i, _):
|
||||
return i < 3
|
||||
def body_fn(i, x):
|
||||
return i + 1, cos(x)
|
||||
_, y = while_loop(cond_fn, body_fn, (0, x))
|
||||
return y
|
||||
|
||||
def g(x):
|
||||
return cos(cos(cos(x)))
|
||||
|
||||
print(jvp(f, (1.,), (1.,)))
|
||||
print(jvp(g, (1.,), (1.,)))
|
||||
```
|
||||
|
||||
The vmap rule for `while_loop` presents two cases:
|
||||
1. if the output of `cond_fun` is not batched, then the loop has the same
|
||||
basic structure, just with a batched body;
|
||||
2. but if the output of `cond_fun` is batched, we must represent a batch of
|
||||
loops which might run for different numbers of iterations.
|
||||
|
||||
...Stay tuned for the thrilling conclusion!
|
||||
|
@ -2188,7 +2188,7 @@ print(sin_lin(1.), cos(3.))
|
||||
#
|
||||
# There are actually two rules to write: one for trace-time partial evaluation,
|
||||
# which we'll call `xla_call_partial_eval`, and one for partial evaluation of
|
||||
# jaxprs, whicch we'll call `xla_call_peval_eqn`.
|
||||
# jaxprs, which we'll call `xla_call_peval_eqn`.
|
||||
|
||||
# +
|
||||
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
|
||||
@ -2548,7 +2548,7 @@ _, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
|
||||
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
|
||||
# -
|
||||
|
||||
# ## Part 5: the control flow primitives `cond` and `while_loop`
|
||||
# ## Part 5: the control flow primitives `cond`
|
||||
#
|
||||
# Next we'll add higher-order primitives for staged-out control flow. These
|
||||
# resemble `jit` from Part 3, another higher-order primitive, but differ in that
|
||||
@ -2845,112 +2845,3 @@ transpose_rules[cond_p] = cond_transpose_rule
|
||||
|
||||
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
|
||||
print(out)
|
||||
|
||||
|
||||
# ### Adding `while_loop`
|
||||
#
|
||||
# Next we'll add a primitive for looping behavior in a jaxpr. We'll use
|
||||
# `while_loop : (a -> Bool) -> (a -> a) -> a -> a`, where the first
|
||||
# function-valued argument represents the loop condition, the second represents
|
||||
# the loop body, and the final argument is the initial value of the carry.
|
||||
#
|
||||
# After `cond`, adding `while_loop` is not so different:
|
||||
|
||||
def while_loop(cond_fn, body_fn, init_val):
|
||||
init_val, in_tree = tree_flatten(init_val)
|
||||
avals_in = [raise_to_shaped(get_aval(x)) for x in init_val]
|
||||
cond_jaxpr, cond_consts, cond_tree = make_jaxpr(cond_fn, *avals_in)
|
||||
body_jaxpr, body_consts, in_tree_ = make_jaxpr(body_fn, *avals_in)
|
||||
cond_jaxpr, body_jaxpr = _join_jaxpr_consts(
|
||||
cond_jaxpr, body_jaxpr, len(cond_consts), len(body_consts))
|
||||
if cond_tree != tree_flatten(True)[1]: raise TypeError
|
||||
if in_tree != in_tree_: raise TypeError
|
||||
outs = bind(while_loop_p, *cond_consts, *body_consts, *init_val,
|
||||
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
return tree_unflatten(in_tree, outs)
|
||||
while_loop_p = Primitive('while_loop')
|
||||
|
||||
def while_loop_impl(*args, cond_jaxpr, body_jaxpr):
|
||||
consts, carry = split_list(args, _loop_num_consts(body_jaxpr))
|
||||
while eval_jaxpr(cond_jaxpr, [*consts, *carry])[0]:
|
||||
carry = eval_jaxpr(body_jaxpr, [*consts, *carry])
|
||||
return carry
|
||||
impl_rules[while_loop_p] = while_loop_impl
|
||||
|
||||
def _loop_num_consts(body_jaxpr: Jaxpr) -> int:
|
||||
return len(body_jaxpr.in_binders) - len(body_jaxpr.outs)
|
||||
|
||||
out = while_loop(lambda x: x > 0, lambda x: x + -3, 10)
|
||||
print(out)
|
||||
|
||||
# Notice the convention that `args = [*consts, *carry]`.
|
||||
#
|
||||
# The `while_loop` JVP rule introduces a wrinkle. For `jvp_jaxpr`, we have the
|
||||
# convention that all the binders for tangent values are appended after all the
|
||||
# binders for primal values, like `args = [*primals, *tangents]`. But that's in
|
||||
# tension with our `while_loop` convention that the carry binders come after all
|
||||
# the constant binders, i.e. that `args = [*consts, *carry]`, because both the
|
||||
# constants and the carries can have their own tangents. For this reason, we
|
||||
# introduce the `_loop_jvp_binders` helper to rearrange binders as needed.
|
||||
|
||||
# +
|
||||
def while_loop_jvp_rule(primals, tangents, *, cond_jaxpr, body_jaxpr):
|
||||
num_consts = _loop_num_consts(body_jaxpr)
|
||||
body_jaxpr, body_consts = jvp_jaxpr(body_jaxpr)
|
||||
cond_jaxpr, body_jaxpr = _loop_jvp_binders(
|
||||
cond_jaxpr, body_jaxpr, len(body_consts), num_consts)
|
||||
outs = bind(while_loop_p, *body_consts, *primals, *tangents,
|
||||
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
primals_out, tangents_out = split_half(outs)
|
||||
return primals_out, tangents_out
|
||||
jvp_rules[while_loop_p] = while_loop_jvp_rule
|
||||
|
||||
def _loop_jvp_binders(cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr, n1: int, n2: int
|
||||
) -> Jaxpr:
|
||||
# body binders [c1, c2, x1, c2dot, x2dot] ~~> [c1, c2, c2dot, x1, x1dot]
|
||||
jvp_const_binders, binders = split_list(body_jaxpr.in_binders, n1)
|
||||
primal_binders, tangent_binders = split_half(binders)
|
||||
consts , carry = split_list(primal_binders , n2)
|
||||
consts_dot, carry_dot = split_list(tangent_binders, n2)
|
||||
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
|
||||
new_body_jaxpr = Jaxpr(new_in_binders, body_jaxpr.eqns, body_jaxpr.outs)
|
||||
typecheck_jaxpr(new_body_jaxpr)
|
||||
|
||||
# cond binders [c2, x1] ~~> [c1, c2, c2dot, x1, x1dot]
|
||||
assert not set(new_body_jaxpr.in_binders) & set(cond_jaxpr.in_binders)
|
||||
consts, carry = split_list(cond_jaxpr.in_binders, n2)
|
||||
new_in_binders = jvp_const_binders + consts + consts_dot + carry + carry_dot
|
||||
new_cond_jaxpr = Jaxpr(new_in_binders, cond_jaxpr.eqns, cond_jaxpr.outs)
|
||||
|
||||
return new_cond_jaxpr, new_body_jaxpr
|
||||
|
||||
|
||||
# -
|
||||
|
||||
out, out_tan = jvp(lambda x: while_loop(lambda x: x < 10., lambda x: x * 2., x),
|
||||
(1.,), (1.,))
|
||||
print(out_tan)
|
||||
|
||||
# +
|
||||
def f(x):
|
||||
def cond_fn(i, _):
|
||||
return i < 3
|
||||
def body_fn(i, x):
|
||||
return i + 1, cos(x)
|
||||
_, y = while_loop(cond_fn, body_fn, (0, x))
|
||||
return y
|
||||
|
||||
def g(x):
|
||||
return cos(cos(cos(x)))
|
||||
|
||||
print(jvp(f, (1.,), (1.,)))
|
||||
print(jvp(g, (1.,), (1.,)))
|
||||
# -
|
||||
|
||||
# The vmap rule for `while_loop` presents two cases:
|
||||
# 1. if the output of `cond_fun` is not batched, then the loop has the same
|
||||
# basic structure, just with a batched body;
|
||||
# 2. but if the output of `cond_fun` is batched, we must represent a batch of
|
||||
# loops which might run for different numbers of iterations.
|
||||
#
|
||||
# ...Stay tuned for the thrilling conclusion!
|
||||
|
Loading…
x
Reference in New Issue
Block a user