reviewer comments, delint

Matthew Johnson 2021-02-24 20:25:24 -08:00
parent 0fcecb2bd2
commit 9789677e85
3 changed files with 68 additions and 59 deletions

@ -35,10 +35,11 @@
"where we apply primitive functions to numerical inputs to produce numerical\n",
"outputs, we want to override primitive application and let different values\n",
"flow through our program. For example, we might want to replace the\n",
"application of every primitive with type `a -> b` with an application of its\n",
"JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow\n",
"through our program. Moreover, we want to apply a composition of multiple\n",
"transformations, leading to stacks of interpreters."
"application of every primitive with an application of [its JVP\n",
"rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
"and let primal-tangent pairs flow through our program. Moreover, we want to\n",
"apply a composition of multiple transformations, leading to stacks of\n",
"interpreters."
]
},
{
@ -144,7 +145,7 @@
"source": [
"When we're about to apply a transformed function, we'll push another\n",
"interpreter onto the stack using `new_main`. Then, as we apply primitives in\n",
"the function, we can think of the `bind` first being interprted by the trace\n",
"the function, we can think of the `bind` first being interpreted by the trace\n",
"at the top of the stack (i.e. with the highest level). If that first\n",
"interpreter itself binds other primitives in its interpretation rule for the\n",
"primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,\n",
@ -283,7 +284,7 @@
"\n",
" @staticmethod\n",
" def _nonzero(tracer):\n",
" return nonzero(tracer.aval.val)\n",
" return bool(tracer.aval.val)\n",
"\n",
"def get_aval(x):\n",
" if isinstance(x, Tracer):\n",
@ -474,7 +475,7 @@
"source": [
"### Forward-mode autodiff with `jvp`\n",
"\n",
"First, a couple helper functions:"
"First, a couple of helper functions:"
]
},
{
@ -578,7 +579,7 @@
"\n",
"def reduce_sum_jvp(primals, tangents, *, axis):\n",
" (x,), (x_dot,) = primals, tangents\n",
" return reduce_sum(x_dot, axis), reduce_sum(x_dot, axis)\n",
" return reduce_sum(x, axis), reduce_sum(x_dot, axis)\n",
"jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
"\n",
"def greater_jvp(primals, tangents):\n",
@ -805,7 +806,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we've implemented the optional `Tracer.full_lower` method, which lets\n",
"Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
"peel off a batching tracer if it's not needed because it doesn't represent a\n",
"batched value.\n",
"\n",
@ -965,18 +966,20 @@
"carry a little bit of extra context, but for both `jit` and `vjp` we need\n",
"much richer context: we need to represent _programs_. That is, we need jaxprs!\n",
"\n",
"We need a program representation for `jit` because the purpose of `jit` is to\n",
"stage computation out of Python. For any computation we want to stage out,\n",
"we need to be able to represent it as data, and build it up as we trace a\n",
"Python function. Similarly, `vjp` needs a way to represent the computation for\n",
"the backward pass of reverse-mode autodiff. We use the same jaxpr program\n",
"representation for both needs.\n",
"Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are\n",
"an explicitly typed, functional, first-order language. We need a program\n",
"representation for `jit` because the purpose of `jit` is to stage computation\n",
"out of Python. For any computation we want to stage out, we need to be able to\n",
"represent it as data, and build it up as we trace a Python function.\n",
"Similarly, `vjp` needs a way to represent the computation for the backward\n",
"pass of reverse-mode autodiff. We use the same jaxpr program representation\n",
"for both needs.\n",
"\n",
"(Building a program representation is the most\n",
"[free](https://en.wikipedia.org/wiki/Free_object) kind of trace-\n",
"transformation, and so except for issues around handling native Python control\n",
"flow, any transformation could be implemented by first tracing to a jaxpr\n",
"and then interpreting the jaxpr.)\n",
"[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
"trace- transformation, and so except for issues around handling native Python\n",
"control flow, any transformation could be implemented by first tracing to a\n",
"jaxpr and then interpreting the jaxpr.)\n",
"\n",
"The jaxpr term syntax is roughly:\n",
"\n",
@ -1035,7 +1038,7 @@
" primitive: Primitive\n",
" inputs: List[Atom]\n",
" params: Dict[str, Any]\n",
" out_binder: List[Var]\n",
" out_binder: Var\n",
"\n",
"class Jaxpr(NamedTuple):\n",
" in_binders: List[Var]\n",

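Note on the `reduce_sum_jvp` hunk above: the old rule returned the summed tangent in both output slots, so the primal output was wrong whenever `x` and `x_dot` differed. The sketch below contrasts the two versions. It is a minimal illustration, not the notebook's code: `reduce_sum` here is a hypothetical stand-in that sums a flat Python list and ignores `axis`, and `reduce_sum_jvp_old` is a name introduced only for the comparison.

```python
# Hypothetical stand-in for the notebook's reduce_sum primitive.
# It sums a flat list of floats; `axis` is accepted but ignored.
def reduce_sum(x, axis=None):
    return sum(x)

# Rule as it stood before this commit: the primal slot reuses x_dot.
def reduce_sum_jvp_old(primals, tangents, *, axis):
    (x,), (x_dot,) = primals, tangents
    return reduce_sum(x_dot, axis), reduce_sum(x_dot, axis)

# Rule after this commit: primal output from x, tangent output from x_dot.
def reduce_sum_jvp(primals, tangents, *, axis):
    (x,), (x_dot,) = primals, tangents
    return reduce_sum(x, axis), reduce_sum(x_dot, axis)

x = [1.0, 2.0, 3.0]
x_dot = [0.5, 0.0, -0.5]

old_primal_out, old_tangent_out = reduce_sum_jvp_old((x,), (x_dot,), axis=None)
primal_out, tangent_out = reduce_sum_jvp((x,), (x_dot,), axis=None)
```

Because summation is linear, the tangent output is the same in both versions; only the primal output changes.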
@ -6,7 +6,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.10.0
jupytext_version: 1.10.2
kernelspec:
display_name: Python 3
name: python3
@ -40,10 +40,11 @@ atomic units of processing rather than compositions.
where we apply primitive functions to numerical inputs to produce numerical
outputs, we want to override primitive application and let different values
flow through our program. For example, we might want to replace the
application of every primitive with type `a -> b` with an application of its
JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow
through our program. Moreover, we want to apply a composition of multiple
transformations, leading to stacks of interpreters.
application of every primitive with an application of [its JVP
rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
and let primal-tangent pairs flow through our program. Moreover, we want to
apply a composition of multiple transformations, leading to stacks of
interpreters.
+++
@ -124,7 +125,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):
When we're about to apply a transformed function, we'll push another
interpreter onto the stack using `new_main`. Then, as we apply primitives in
the function, we can think of the `bind` first being interprted by the trace
the function, we can think of the `bind` first being interpreted by the trace
at the top of the stack (i.e. with the highest level). If that first
interpreter itself binds other primitives in its interpretation rule for the
primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,
@ -247,7 +248,7 @@ class ConcreteArray(ShapedArray):
@staticmethod
def _nonzero(tracer):
return nonzero(tracer.aval.val)
return bool(tracer.aval.val)
def get_aval(x):
if isinstance(x, Tracer):
@ -370,7 +371,7 @@ that now we can add some real transformations.
### Forward-mode autodiff with `jvp`
First, a couple helper functions:
First, a couple of helper functions:
```{code-cell}
def zeros_like(val):
@ -445,7 +446,7 @@ jvp_rules[neg_p] = neg_jvp
def reduce_sum_jvp(primals, tangents, *, axis):
(x,), (x_dot,) = primals, tangents
return reduce_sum(x_dot, axis), reduce_sum(x_dot, axis)
return reduce_sum(x, axis), reduce_sum(x_dot, axis)
jvp_rules[reduce_sum_p] = reduce_sum_jvp
def greater_jvp(primals, tangents):
@ -575,7 +576,7 @@ class BatchTrace(Trace):
vmap_rules = {}
```
Here we've implemented the optional `Tracer.full_lower` method, which lets
Here we've implemented the optional `Tracer.full_lower` method, which lets us
peel off a batching tracer if it's not needed because it doesn't represent a
batched value.
@ -678,18 +679,20 @@ wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to
carry a little bit of extra context, but for both `jit` and `vjp` we need
much richer context: we need to represent _programs_. That is, we need jaxprs!
We need a program representation for `jit` because the purpose of `jit` is to
stage computation out of Python. For any computation we want to stage out,
we need to be able to represent it as data, and build it up as we trace a
Python function. Similarly, `vjp` needs a way to represent the computation for
the backward pass of reverse-mode autodiff. We use the same jaxpr program
representation for both needs.
Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are
an explicitly typed, functional, first-order language. We need a program
representation for `jit` because the purpose of `jit` is to stage computation
out of Python. For any computation we want to stage out, we need to be able to
represent it as data, and build it up as we trace a Python function.
Similarly, `vjp` needs a way to represent the computation for the backward
pass of reverse-mode autodiff. We use the same jaxpr program representation
for both needs.
(Building a program representation is the most
[free](https://en.wikipedia.org/wiki/Free_object) kind of trace-
transformation, and so except for issues around handling native Python control
flow, any transformation could be implemented by first tracing to a jaxpr
and then interpreting the jaxpr.)
[free](https://en.wikipedia.org/wiki/Free_object) kind of
trace- transformation, and so except for issues around handling native Python
control flow, any transformation could be implemented by first tracing to a
jaxpr and then interpreting the jaxpr.)
The jaxpr term syntax is roughly:
@ -742,7 +745,7 @@ class JaxprEqn(NamedTuple):
primitive: Primitive
inputs: List[Atom]
params: Dict[str, Any]
out_binder: List[Var]
out_binder: Var
class Jaxpr(NamedTuple):
in_binders: List[Var]
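
Note on the `JaxprEqn` hunk above: after this commit each equation binds one `Var` rather than a list of them. The sketch below shows what constructing such an equation might look like. The `Var` and `Lit` classes and the string-valued `primitive` field are simplified, hypothetical stand-ins for the notebook's own types, which are not shown in this diff.

```python
from typing import Any, Dict, List, NamedTuple, Union

# Hypothetical, simplified stand-ins for the notebook's Var and Lit types.
class Var(NamedTuple):
    name: str

class Lit(NamedTuple):
    val: Any

Atom = Union[Var, Lit]

class JaxprEqn(NamedTuple):
    primitive: str          # assumption: a plain name instead of a Primitive object
    inputs: List[Atom]
    params: Dict[str, Any]
    out_binder: Var         # a single Var, as in the new version of the field

# An equation representing y = mul x 2.0
x, y = Var('x'), Var('y')
eqn = JaxprEqn(primitive='mul', inputs=[x, Lit(2.0)], params={}, out_binder=y)
```

With a single binder, code that consumes an equation can read `eqn.out_binder` directly instead of unpacking a one-element list.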

@ -6,7 +6,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.10.2
# jupytext_version: 1.10.0
# kernelspec:
# display_name: Python 3
# name: python3
@ -38,10 +38,11 @@
# where we apply primitive functions to numerical inputs to produce numerical
# outputs, we want to override primitive application and let different values
# flow through our program. For example, we might want to replace the
# application of every primitive with type `a -> b` with an application of its
# JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow
# through our program. Moreover, we want to apply a composition of multiple
# transformations, leading to stacks of interpreters.
# application of every primitive with an application of [its JVP
# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
# and let primal-tangent pairs flow through our program. Moreover, we want to
# apply a composition of multiple transformations, leading to stacks of
# interpreters.
# ### JAX core machinery
#
@ -125,7 +126,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):
# When we're about to apply a transformed function, we'll push another
# interpreter onto the stack using `new_main`. Then, as we apply primitives in
# the function, we can think of the `bind` first being interprted by the trace
# the function, we can think of the `bind` first being interpreted by the trace
# at the top of the stack (i.e. with the highest level). If that first
# interpreter itself binds other primitives in its interpretation rule for the
# primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,
@ -375,7 +376,7 @@ print(f(3.0))
# ### Forward-mode autodiff with `jvp`
#
# First, a couple helper functions:
# First, a couple of helper functions:
# +
def zeros_like(val):
@ -586,7 +587,7 @@ class BatchTrace(Trace):
vmap_rules = {}
# -
# Here we've implemented the optional `Tracer.full_lower` method, which lets
# Here we've implemented the optional `Tracer.full_lower` method, which lets us
# peel off a batching tracer if it's not needed because it doesn't represent a
# batched value.
#
@ -688,18 +689,20 @@ jacfwd(f, np.arange(3.))
# carry a little bit of extra context, but for both `jit` and `vjp` we need
# much richer context: we need to represent _programs_. That is, we need jaxprs!
#
# We need a program representation for `jit` because the purpose of `jit` is to
# stage computation out of Python. For any computation we want to stage out,
# we need to be able to represent it as data, and build it up as we trace a
# Python function. Similarly, `vjp` needs a way to represent the computation for
# the backward pass of reverse-mode autodiff. We use the same jaxpr program
# representation for both needs.
# Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are
# an explicitly typed, functional, first-order language. We need a program
# representation for `jit` because the purpose of `jit` is to stage computation
# out of Python. For any computation we want to stage out, we need to be able to
# represent it as data, and build it up as we trace a Python function.
# Similarly, `vjp` needs a way to represent the computation for the backward
# pass of reverse-mode autodiff. We use the same jaxpr program representation
# for both needs.
#
# (Building a program representation is the most
# [free](https://en.wikipedia.org/wiki/Free_object) kind of trace-
# transformation, and so except for issues around handling native Python control
# flow, any transformation could be implemented by first tracing to a jaxpr
# and then interpreting the jaxpr.)
# [free](https://en.wikipedia.org/wiki/Free_object) kind of
# trace- transformation, and so except for issues around handling native Python
# control flow, any transformation could be implemented by first tracing to a
# jaxpr and then interpreting the jaxpr.)
#
# The jaxpr term syntax is roughly:
#
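
Note on the `_nonzero` hunk earlier in this commit: the old body called a bare `nonzero(...)`, and the new one calls the built-in `bool(...)` on the concrete value. Assuming no `nonzero` helper is defined elsewhere in the notebook (the diff does not show one), the old line would fail with a `NameError` the first time a tracer was used in a boolean context. The sketch below shows the new behavior with pared-down, hypothetical versions of `ConcreteArray` and `Tracer`; the real classes carry more state.

```python
# Hypothetical, pared-down versions of the notebook's classes.
class ConcreteArray:
    def __init__(self, val):
        self.val = val

    @staticmethod
    def _nonzero(tracer):
        # New body from this commit: defer to Python's built-in truthiness.
        return bool(tracer.aval.val)

class Tracer:
    def __init__(self, aval):
        self.aval = aval

    # Assumption: boolean conversion is forwarded to the abstract value.
    def __bool__(self):
        return self.aval._nonzero(self)

truthy = bool(Tracer(ConcreteArray(3.0)))
falsy = bool(Tracer(ConcreteArray(0.0)))
```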