mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Mirror the minor spelling fixes over the {.ipynb, .md, .py}.
This commit is contained in:
parent
a5a93dc845
commit
0d19b7c082
@ -2718,8 +2718,8 @@
|
||||
"interpreter will build a jaxpr on the fly while tracking data dependencies. To\n",
|
||||
"do so, it builds a bipartite directed acyclic graph (DAG) between\n",
|
||||
"`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
|
||||
"nodes, representing formulas for how to compute some values from others. One kind\n",
|
||||
"of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
|
||||
"nodes, representing formulas for how to compute some values from others. One\n",
|
||||
"kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive\n",
|
||||
"application, but we also have recipe types for constants and lambda binders:"
|
||||
]
|
||||
},
|
||||
|
@ -71,7 +71,7 @@ flow through our program. For example, we might want to replace the
|
||||
application of every primitive with an application of [its JVP
|
||||
rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
|
||||
and let primal-tangent pairs flow through our program. Moreover, we want to be
|
||||
able to comopse multiple transformations, leading to stacks of interpreters.
|
||||
able to compose multiple transformations, leading to stacks of interpreters.
|
||||
|
||||
+++
|
||||
|
||||
@ -125,7 +125,7 @@ to `bind`, and in particular we follow a handy internal convention: when we
|
||||
call `bind`, we pass values representing array data as positional arguments,
|
||||
and we pass metadata like the `axis` argument to `sum_p` via keyword. This
|
||||
calling convention simplifies some core logic (since e.g. instances of the
|
||||
`Tracer` class to be defined below can only occurr in positional arguments to
|
||||
`Tracer` class to be defined below can only occur in positional arguments to
|
||||
`bind`). The wrappers can also provide docstrings!
|
||||
|
||||
We represent active interpreters as a stack. The stack is just a simple
|
||||
@ -174,7 +174,7 @@ evaluation. So at the bottom we'll put an evaluation interpreter.
|
||||
Let's sketch out the interface for interpreters, which is based on the `Trace`
|
||||
and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps
|
||||
carrying some extra context data used by the interpreter. A `Trace` handles
|
||||
boxing up vales into `Tracers` and also handles primitive application.
|
||||
boxing up values into `Tracers` and also handles primitive application.
|
||||
|
||||
```{code-cell}
|
||||
class Trace:
|
||||
@ -1750,7 +1750,7 @@ print(ys)
|
||||
|
||||
One piece missing is device memory persistence for arrays. That is, we've
|
||||
defined `handle_result` to transfer results back to CPU memory as NumPy
|
||||
arrays, but it's often preferrable to avoid transferring results just to
|
||||
arrays, but it's often preferable to avoid transferring results just to
|
||||
transfer them back for the next operation. We can do that by introducing a
|
||||
`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
|
||||
`numpy.ndarray`s:
|
||||
@ -1831,7 +1831,7 @@ we evaluate all the primal values as we trace, but stage the tangent
|
||||
computations into a jaxpr. This is our second way to build jaxprs. But where
|
||||
`make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim
|
||||
to stage out every primitive bind, this second approach stages out only those
|
||||
primitive binds with a data dependence on tagent inputs.
|
||||
primitive binds with a data dependence on tangent inputs.
|
||||
|
||||
First, some utilities:
|
||||
|
||||
@ -1897,7 +1897,7 @@ behavior relies on the data dependencies inside the given Python callable and
|
||||
not just its type. Nevertheless a heuristic type signature is useful. If we
|
||||
assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
|
||||
`a1` and `a2` represent the known and unknown inputs, respectively, and where
|
||||
`b1` only has a data depenence on `a1` while `b2` has some data dependnece on
|
||||
`b1` only has a data dependency on `a1` while `b2` has some data dependency on
|
||||
`a2`, then we might write
|
||||
|
||||
```
|
||||
@ -2000,8 +2000,8 @@ Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This
|
||||
interpreter will build a jaxpr on the fly while tracking data dependencies. To
|
||||
do so, it builds a bipartite directed acyclic graph (DAG) between
|
||||
`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
|
||||
nodes, representing formulas for how compute some values from others. One kind
|
||||
of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
|
||||
nodes, representing formulas for how to compute some values from others. One
|
||||
kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
|
||||
application, but we also have recipe types for constants and lambda binders:
|
||||
|
||||
```{code-cell}
|
||||
|
@ -60,7 +60,7 @@
|
||||
# application of every primitive with an application of [its JVP
|
||||
# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
|
||||
# and let primal-tangent pairs flow through our program. Moreover, we want to be
|
||||
# able to comopse multiple transformations, leading to stacks of interpreters.
|
||||
# able to compose multiple transformations, leading to stacks of interpreters.
|
||||
|
||||
# ### JAX core machinery
|
||||
#
|
||||
@ -113,7 +113,7 @@ def bind1(prim, *args, **params):
|
||||
# call `bind`, we pass values representing array data as positional arguments,
|
||||
# and we pass metadata like the `axis` argument to `sum_p` via keyword. This
|
||||
# calling convention simplifies some core logic (since e.g. instances of the
|
||||
# `Tracer` class to be defined below can only occurr in positional arguments to
|
||||
# `Tracer` class to be defined below can only occur in positional arguments to
|
||||
# `bind`). The wrappers can also provide docstrings!
|
||||
#
|
||||
# We represent active interpreters as a stack. The stack is just a simple
|
||||
@ -162,7 +162,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):
|
||||
# Let's sketch out the interface for interpreters, which is based on the `Trace`
|
||||
# and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps
|
||||
# carrying some extra context data used by the interpreter. A `Trace` handles
|
||||
# boxing up vales into `Tracers` and also handles primitive application.
|
||||
# boxing up values into `Tracers` and also handles primitive application.
|
||||
|
||||
class Trace:
|
||||
main: MainTrace
|
||||
@ -1672,7 +1672,7 @@ print(ys)
|
||||
|
||||
# One piece missing is device memory persistence for arrays. That is, we've
|
||||
# defined `handle_result` to transfer results back to CPU memory as NumPy
|
||||
# arrays, but it's often preferrable to avoid transferring results just to
|
||||
# arrays, but it's often preferable to avoid transferring results just to
|
||||
# transfer them back for the next operation. We can do that by introducing a
|
||||
# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
|
||||
# `numpy.ndarray`s:
|
||||
@ -1750,7 +1750,7 @@ print(ydot)
|
||||
# computations into a jaxpr. This is our second way to build jaxprs. But where
|
||||
# `make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim
|
||||
# to stage out every primitive bind, this second approach stages out only those
|
||||
# primitive binds with a data dependence on tagent inputs.
|
||||
# primitive binds with a data dependence on tangent inputs.
|
||||
#
|
||||
# First, some utilities:
|
||||
|
||||
@ -1816,7 +1816,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
|
||||
# not just its type. Nevertheless a heuristic type signature is useful. If we
|
||||
# assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where
|
||||
# `a1` and `a2` represent the known and unknown inputs, respectively, and where
|
||||
# `b1` only has a data depenence on `a1` while `b2` has some data dependnece on
|
||||
# `b1` only has a data dependency on `a1` while `b2` has some data dependency on
|
||||
# `a2`, then we might write
|
||||
#
|
||||
# ```
|
||||
@ -1915,9 +1915,10 @@ def partial_eval_flat(f, pvals_in: List[PartialVal]):
|
||||
# interpreter will build a jaxpr on the fly while tracking data dependencies. To
|
||||
# do so, it builds a bipartite directed acyclic graph (DAG) between
|
||||
# `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`
|
||||
# nodes, representing formulas for how compute some values from others. One kind
|
||||
# of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive
|
||||
# application, but we also have recipe types for constants and lambda binders:
|
||||
# nodes, representing formulas for how to compute some values from others. One
|
||||
# kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s
|
||||
# primitive application, but we also have recipe types for constants and lambda
|
||||
# binders:
|
||||
|
||||
# +
|
||||
from weakref import ref, ReferenceType
|
||||
|
Loading…
x
Reference in New Issue
Block a user