Mirror the minor spelling fixes over the {.ipynb, .md, .py}.

This commit is contained in:
James Knighton 2021-03-27 21:47:07 -07:00
parent a5a93dc845
commit 0d19b7c082
3 changed files with 20 additions and 19 deletions

View File

@@ -2718,8 +2718,8 @@
 "interpreter will build a jaxpr on the fly while tracking data dependencies. To\n",
 "do so, it builds a bipartite directed acyclic graph (DAG) between\n",
 "`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
-"nodes, representing formulas for how to compute some values from others. One kind\n",
-"of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
+"nodes, representing formulas for how to compute some values from others. One\n",
+"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
 "application, but we also have recipe types for constants and lambda binders:"
 ]
 },
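The bipartite DAG described in this hunk lends itself to a small sketch. The classes below are simplified stand-ins for illustration, not the tutorial's exact definitions:

```python
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class LambdaBindingRecipe:
    """Marks a tracer as a lambda binder, i.e. an input of the jaxpr."""

@dataclass
class ConstRecipe:
    """A constant captured from the environment."""
    val: Any

@dataclass
class JaxprEqnRecipe:
    """A formula: how to compute a value from other staged-out values."""
    prim: str                              # primitive name, e.g. 'mul'
    tracers_in: List['PartialEvalTracer']  # DAG edges back to the inputs
    params: Dict[str, Any]                 # static metadata for the equation

@dataclass
class PartialEvalTracer:
    """A staged-out value; its recipe records how it came to be."""
    recipe: Any

# Hand-building the DAG for `out = mul(x, 2.0)`:
x = PartialEvalTracer(LambdaBindingRecipe())
two = PartialEvalTracer(ConstRecipe(2.0))
out = PartialEvalTracer(JaxprEqnRecipe('mul', [x, two], {}))
```

In the tutorial, tracing builds such a graph implicitly and a jaxpr is then read off it by topological traversal from the outputs.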

View File

@@ -71,7 +71,7 @@ flow through our program. For example, we might want to replace the
 application of every primitive with an application of [its JVP
 rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
 and let primal-tangent pairs flow through our program. Moreover, we want to be
-able to comopse multiple transformations, leading to stacks of interpreters.
+able to compose multiple transformations, leading to stacks of interpreters.
 +++
@@ -125,7 +125,7 @@ to `bind`, and in particular we follow a handy internal convention: when we
 call `bind`, we pass values representing array data as positional arguments,
 and we pass metadata like the `axis` argument to `sum_p` via keyword. This
 calling convention simplifies some core logic (since e.g. instances of the
-`Tracer` class to be defined below can only occurr in positional arguments to
+`Tracer` class to be defined below can only occur in positional arguments to
 `bind`). The wrappers can also provide docstrings!
 We represent active interpreters as a stack. The stack is just a simple
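The convention in the hunk above (array data positional, metadata by keyword) can be shown with a toy stand-in for `bind`; the names here are hypothetical, not the tutorial's code:

```python
class Tracer:
    """Toy boxed value that an interpreter would intercept."""
    def __init__(self, val):
        self.val = val

def bind(impl, *args, **params):
    # Only positional arguments can carry Tracers, so this one scan is all
    # the core logic needs; metadata in **params passes through untouched.
    unboxed = [a.val if isinstance(a, Tracer) else a for a in args]
    return impl(*unboxed, **params)

def sum_impl(x, *, axis=None):
    return sum(x)      # toy 1-D implementation; `axis` is just metadata here

def my_sum(x, *, axis=None):
    # User-facing wrapper: data goes positionally, `axis` goes by keyword.
    return bind(sum_impl, x, axis=axis)

print(my_sum(Tracer([1, 2, 3]), axis=0))  # prints 6
```

Because metadata never rides in a positional slot, an interpreter never has to guess whether a keyword argument might be traced data.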
@@ -174,7 +174,7 @@ evaluation. So at the bottom we'll put an evaluation interpreter.
 Let's sketch out the interface for interpreters, which is based on the `Trace`
 and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps
 carrying some extra context data used by the interpreter. A `Trace` handles
-boxing up vales into `Tracers` and also handles primitive application.
+boxing up values into `Tracers` and also handles primitive application.
 ```{code-cell}
 class Trace:
@@ -1750,7 +1750,7 @@ print(ys)
 One piece missing is device memory persistence for arrays. That is, we've
 defined `handle_result` to transfer results back to CPU memory as NumPy
-arrays, but it's often preferrable to avoid transferring results just to
+arrays, but it's often preferable to avoid transferring results just to
 transfer them back for the next operation. We can do that by introducing a
 `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
 `numpy.ndarray`s:
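A minimal duck-typing sketch of the `DeviceArray` idea above, with a NumPy array standing in for an XLA buffer (names simplified; not the tutorial's class):

```python
import numpy as np

class DeviceArray:
    """Wraps a 'device buffer' and only copies to host when NumPy asks."""
    def __init__(self, buf):
        self._buf = buf          # stand-in for an on-device XLA buffer

    @property
    def shape(self):
        return self._buf.shape

    @property
    def dtype(self):
        return self._buf.dtype

    def __array__(self, dtype=None, copy=None):
        # Materialization point: transfer back to host memory on demand.
        return np.asarray(self._buf, dtype=dtype)

x = DeviceArray(np.arange(3.0))
print(np.sum(x))                 # NumPy consumes x via __array__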
@@ -1831,7 +1831,7 @@ we evaluate all the primal values as we trace, but stage the tangent
 computations into a jaxpr. This is our second way to build jaxprs. But where
 `make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim
 to stage out every primitive bind, this second approach stages out only those
-primitive binds with a data dependence on tagent inputs.
+primitive binds with a data dependence on tangent inputs.
 First, some utilities:
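The dependence rule in the hunk above can be illustrated with a toy evaluator: run anything whose inputs are all known, stage anything that touches an unknown (tangent) value. This is a deliberate simplification, not the tutorial's partial-evaluation machinery:

```python
class Known:
    def __init__(self, val):
        self.val = val

class Staged:
    def __init__(self, expr):
        self.expr = expr           # the recorded formula, as a string

staged_eqns = []                   # our stand-in for the emitted jaxpr

def describe(v):
    return str(v.val) if isinstance(v, Known) else v.expr

def mul(x, y):
    if isinstance(x, Known) and isinstance(y, Known):
        return Known(x.val * y.val)   # no unknown inputs: evaluate now
    eqn = f"mul({describe(x)}, {describe(y)})"
    staged_eqns.append(eqn)           # depends on an unknown: stage it out
    return Staged(eqn)

x = Known(2.0)                     # known primal input
t = Staged("t")                    # unknown tangent input
y = mul(x, x)                      # evaluated eagerly during tracing
out = mul(y, t)                    # staged into the residual program
print(y.val, staged_eqns)          # 4.0 ['mul(4.0, t)']
```

Note that `y` never appears in the staged equations: only the binds downstream of the unknown input survive into the residual program.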
@@ -1897,7 +1897,7 @@ behavior relies on the data dependencies inside the given Python callable and
 not just its type. Nevertheless a heuristic type signature is useful. If we
 assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
 `a1` and `a2` represent the known and unknown inputs, respectively, and where
-`b1` only has a data depenence on `a1` while `b2` has some data dependnece on
+`b1` only has a data dependency on `a1` while `b2` has some data dependency on
 `a2`, then we might write
 ```
@@ -2000,8 +2000,8 @@ Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This
 interpreter will build a jaxpr on the fly while tracking data dependencies. To
 do so, it builds a bipartite directed acyclic graph (DAG) between
 `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
-nodes, representing formulas for how compute some values from others. One kind
-of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
+nodes, representing formulas for how to compute some values from others. One
+kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
 application, but we also have recipe types for constants and lambda binders:
 ```{code-cell}

View File

@@ -60,7 +60,7 @@
 # application of every primitive with an application of [its JVP
 # rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
 # and let primal-tangent pairs flow through our program. Moreover, we want to be
-# able to comopse multiple transformations, leading to stacks of interpreters.
+# able to compose multiple transformations, leading to stacks of interpreters.
 # ### JAX core machinery
 #
@@ -113,7 +113,7 @@ def bind1(prim, *args, **params):
 # call `bind`, we pass values representing array data as positional arguments,
 # and we pass metadata like the `axis` argument to `sum_p` via keyword. This
 # calling convention simplifies some core logic (since e.g. instances of the
-# `Tracer` class to be defined below can only occurr in positional arguments to
+# `Tracer` class to be defined below can only occur in positional arguments to
 # `bind`). The wrappers can also provide docstrings!
 #
 # We represent active interpreters as a stack. The stack is just a simple
@@ -162,7 +162,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):
 # Let's sketch out the interface for interpreters, which is based on the `Trace`
 # and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps
 # carrying some extra context data used by the interpreter. A `Trace` handles
-# boxing up vales into `Tracers` and also handles primitive application.
+# boxing up values into `Tracers` and also handles primitive application.
 class Trace:
   main: MainTrace
@@ -1672,7 +1672,7 @@ print(ys)
 # One piece missing is device memory persistence for arrays. That is, we've
 # defined `handle_result` to transfer results back to CPU memory as NumPy
-# arrays, but it's often preferrable to avoid transferring results just to
+# arrays, but it's often preferable to avoid transferring results just to
 # transfer them back for the next operation. We can do that by introducing a
 # `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
 # `numpy.ndarray`s:
@@ -1750,7 +1750,7 @@ print(ydot)
 # computations into a jaxpr. This is our second way to build jaxprs. But where
 # `make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim
 # to stage out every primitive bind, this second approach stages out only those
-# primitive binds with a data dependence on tagent inputs.
+# primitive binds with a data dependence on tangent inputs.
 #
 # First, some utilities:
@@ -1816,7 +1816,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
 # not just its type. Nevertheless a heuristic type signature is useful. If we
 # assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
 # `a1` and `a2` represent the known and unknown inputs, respectively, and where
-# `b1` only has a data depenence on `a1` while `b2` has some data dependnece on
+# `b1` only has a data dependency on `a1` while `b2` has some data dependency on
 # `a2`, then we might write
 #
 # ```
@@ -1915,9 +1915,10 @@ def partial_eval_flat(f, pvals_in: List[PartialVal]):
 # interpreter will build a jaxpr on the fly while tracking data dependencies. To
 # do so, it builds a bipartite directed acyclic graph (DAG) between
 # `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
-# nodes, representing formulas for how compute some values from others. One kind
-# of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
-# application, but we also have recipe types for constants and lambda binders:
+# nodes, representing formulas for how to compute some values from others. One
+# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
+# primitive application, but we also have recipe types for constants and lambda
+# binders:
 # +
 from weakref import ref, ReferenceType
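The `weakref` import that closes this hunk supports the recipe graph: as we read it, recipes refer back to tracers weakly so that completed traces can be garbage-collected (that motivation is our gloss, not stated in the hunk). The stdlib primitive itself behaves like this:

```python
import weakref

class Node:
    """Any ordinary object can be the target of a weak reference."""

n = Node()
r = weakref.ref(n)       # refer to n without keeping it alive
assert r() is n          # dereference while n is alive
del n                    # drop the last strong reference
assert r() is None       # on CPython the target is collected immediately
```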