commit 9789677e85
parent 0fcecb2bd2

    reviewer comments, delint
@@ -35,10 +35,11 @@
 "where we apply primitive functions to numerical inputs to produce numerical\n",
 "outputs, we want to override primitive application and let different values\n",
 "flow through our program. For example, we might want to replace the\n",
-"application of every primitive with type `a -> b` with an application of its\n",
-"JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow\n",
-"through our program. Moreover, we want to apply a composition of multiple\n",
-"transformations, leading to stacks of interpreters."
+"application of every primitive with an application of [its JVP\n",
+"rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
+"and let primal-tangent pairs flow through our program. Moreover, we want to\n",
+"apply a composition of multiple transformations, leading to stacks of\n",
+"interpreters."
 ]
 },
 {
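Both the old and new wording in this hunk describe the same mechanism: swapping each `a -> b` primitive application for its JVP rule of type `(a, T a) -> (b, T b)`, so that primal-tangent pairs flow through the program. Here is a minimal, standalone sketch of that idea (illustrative names and plain Python floats, not this commit's code):

```python
import math

# A JVP "rule" maps a primal-tangent input pair to a primal-tangent output
# pair, i.e. it has type (a, T a) -> (b, T b) instead of the primitive's a -> b.
def sin_jvp(primal, tangent):
  # d/dx sin(x) = cos(x), so the output tangent is cos(x) * tangent
  return math.sin(primal), math.cos(primal) * tangent

def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  # product rule: d(x * y) = x_dot * y + x * y_dot
  return x * y, x_dot * y + x * y_dot

# Pushing the pair (x, 1.0) through computes f(x) and f'(x) together.
x, x_dot = 2.0, 1.0
y, y_dot = sin_jvp(x, x_dot)
z, z_dot = mul_jvp((x, y), (x_dot, y_dot))  # z = x*sin(x), z_dot = sin(x) + x*cos(x)
print(z, z_dot)
```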
@@ -144,7 +145,7 @@
 "source": [
 "When we're about to apply a transformed function, we'll push another\n",
 "interpreter onto the stack using `new_main`. Then, as we apply primitives in\n",
-"the function, we can think of the `bind` first being interprted by the trace\n",
+"the function, we can think of the `bind` first being interpreted by the trace\n",
 "at the top of the stack (i.e. with the highest level). If that first\n",
 "interpreter itself binds other primitives in its interpretation rule for the\n",
 "primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,\n",
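The paragraph in this hunk describes `bind` being handled by the interpreter with the highest level on the stack, whose rule may itself bind further primitives for the next interpreter down. A simplified sketch of that dispatch order (these names are stand-ins, not the tutorial's `MainTrace`/`Trace` machinery):

```python
# Illustrative-only: a stack of interpreters, with `bind` dispatching a
# primitive to the interpreter at the highest level.
class Interpreter:
  def __init__(self, level, name):
    self.level, self.name = level, name

  def process(self, prim, args):
    print(f"level {self.level} ({self.name}) handles {prim}")
    return args

interpreter_stack = [Interpreter(0, "eval"), Interpreter(1, "jvp")]

def bind(prim, *args):
  top = max(interpreter_stack, key=lambda i: i.level)  # highest level wins
  return top.process(prim, args)

bind("sin", 3.0)  # handled by the "jvp" interpreter, which could bind cos/mul next
```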
@@ -283,7 +284,7 @@
 "\n",
 " @staticmethod\n",
 " def _nonzero(tracer):\n",
-" return nonzero(tracer.aval.val)\n",
+" return bool(tracer.aval.val)\n",
 "\n",
 "def get_aval(x):\n",
 " if isinstance(x, Tracer):\n",
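The `bool(tracer.aval.val)` change above concerns how a tracer that wraps a concrete value answers Python truthiness queries. A rough sketch of the pattern, assuming the abstract value holds a concrete array (simplified classes, not the notebook's):

```python
import numpy as np

class ConcreteAval:
  def __init__(self, val):
    self.val = val

class ConcreteTracer:
  def __init__(self, val):
    self.aval = ConcreteAval(np.asarray(val))

  def __bool__(self):
    # mirrors the patched `_nonzero`: truthiness of the underlying concrete value
    return bool(self.aval.val)

t = ConcreteTracer(2.0)
if t:  # Python control flow works when the tracer carries a concrete value
  print("concrete value is truthy")
```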
@@ -474,7 +475,7 @@
 "source": [
 "### Forward-mode autodiff with `jvp`\n",
 "\n",
-"First, a couple helper functions:"
+"First, a couple of helper functions:"
 ]
 },
 {
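The "couple of helper functions" this cell refers to are small utilities. A hedged sketch of what such helpers typically look like at this point in the tutorial (NumPy used for illustration; the exact definitions may differ):

```python
import numpy as np

def zeros_like(val):
  # a zero tangent with the same shape and dtype as `val`
  return np.zeros_like(val)

def unzip2(pairs):
  # split an iterable of pairs into a pair of lists
  lst1, lst2 = [], []
  for a, b in pairs:
    lst1.append(a)
    lst2.append(b)
  return lst1, lst2

print(unzip2([(1, 'a'), (2, 'b')]))  # ([1, 2], ['a', 'b'])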
@@ -578,7 +579,7 @@
 "\n",
 "def reduce_sum_jvp(primals, tangents, *, axis):\n",
 " (x,), (x_dot,) = primals, tangents\n",
-" return reduce_sum(x_dot, axis), reduce_sum(x_dot, axis)\n",
+" return reduce_sum(x, axis), reduce_sum(x_dot, axis)\n",
 "jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
 "\n",
 "def greater_jvp(primals, tangents):\n",
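The fix above makes the rule return the summed primal alongside the summed tangent. A quick standalone sanity check of the corrected behaviour, using NumPy in place of the tutorial's `reduce_sum` primitive:

```python
import numpy as np

# Summation is linear, so the primal output is sum(x) and the tangent output
# is sum(x_dot); the finite-difference estimate should agree with the tangent.
def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return np.sum(x, axis), np.sum(x_dot, axis)

x = np.array([1.0, 2.0, 3.0])
x_dot = np.array([0.1, 0.2, 0.3])
y, y_dot = reduce_sum_jvp((x,), (x_dot,), axis=0)

eps = 1e-6
fd = (np.sum(x + eps * x_dot) - np.sum(x)) / eps  # finite-difference check
assert np.allclose(y, 6.0) and np.allclose(y_dot, fd)
```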
@@ -805,7 +806,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Here we've implemented the optional `Tracer.full_lower` method, which lets\n",
+"Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
 "peel off a batching tracer if it's not needed because it doesn't represent a\n",
 "batched value.\n",
 "\n",
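The sentence completed above refers to `Tracer.full_lower`, which can hand back the plain value when a batching tracer turns out not to be batched. A hedged sketch of that idea with simplified stand-in classes:

```python
# Illustrative-only: if the batch dimension says "not actually mapped", return
# the underlying value instead of the tracer, so later primitives skip the
# batching machinery entirely.
class NotMapped:
  pass

not_mapped = NotMapped()

class BatchTracerSketch:
  def __init__(self, val, batch_dim):
    self.val, self.batch_dim = val, batch_dim

  def full_lower(self):
    if isinstance(self.batch_dim, NotMapped):
      return self.val  # peel the tracer off: nothing is batched here
    return self

plain = BatchTracerSketch(3.0, not_mapped).full_lower()
print(plain)  # 3.0, not a tracer
```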
@@ -965,18 +966,20 @@
 "carry a little bit of extra context, but for both `jit` and `vjp` we need\n",
 "much richer context: we need to represent _programs_. That is, we need jaxprs!\n",
 "\n",
-"We need a program representation for `jit` because the purpose of `jit` is to\n",
-"stage computation out of Python. For any computation we want to stage out,\n",
-"we need to be able to represent it as data, and build it up as we trace a\n",
-"Python function. Similarly, `vjp` needs a way to represent the computation for\n",
-"the backward pass of reverse-mode autodiff. We use the same jaxpr program\n",
-"representation for both needs.\n",
+"Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are\n",
+"an explicitly typed, functional, first-order language. We need a program\n",
+"representation for `jit` because the purpose of `jit` is to stage computation\n",
+"out of Python. For any computation we want to stage out, we need to be able to\n",
+"represent it as data, and build it up as we trace a Python function.\n",
+"Similarly, `vjp` needs a way to represent the computation for the backward\n",
+"pass of reverse-mode autodiff. We use the same jaxpr program representation\n",
+"for both needs.\n",
 "\n",
 "(Building a program representation is the most\n",
-"[free](https://en.wikipedia.org/wiki/Free_object) kind of trace-\n",
-"transformation, and so except for issues around handling native Python control\n",
-"flow, any transformation could be implemented by first tracing to a jaxpr\n",
-"and then interpreting the jaxpr.)\n",
+"[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
+"trace- transformation, and so except for issues around handling native Python\n",
+"control flow, any transformation could be implemented by first tracing to a\n",
+"jaxpr and then interpreting the jaxpr.)\n",
 "\n",
 "The jaxpr term syntax is roughly:\n",
 "\n",
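The new opening sentences describe jaxprs as JAX's explicitly typed, functional, first-order intermediate representation. For a concrete feel, the released `jax` package (not this notebook's from-scratch code) can print the jaxpr of a small function:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.sin(x) * 2.0)

# Traces f and prints its jaxpr; output shown approximately in the comment below.
print(jax.make_jaxpr(f)(jnp.arange(3.0)))
# { lambda ; a:f32[3]. let b = sin a; c = mul b 2.0; d = reduce_sum[axes=(0,)] c
#   in (d,) }
```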
@@ -1035,7 +1038,7 @@
 " primitive: Primitive\n",
 " inputs: List[Atom]\n",
 " params: Dict[str, Any]\n",
-" out_binder: List[Var]\n",
+" out_binder: Var\n",
 "\n",
 "class Jaxpr(NamedTuple):\n",
 " in_binders: List[Var]\n",
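The hunk above narrows `JaxprEqn.out_binder` from `List[Var]` to a single `Var`. A minimal sketch of how such `NamedTuple` IR types fit together; the pieces beyond this hunk (`eqns`, `outs`, the shape of `Primitive`, `Atom`) are plausible assumptions for illustration, not the notebook's exact definitions:

```python
from typing import Any, Dict, List, NamedTuple

class Var: ...                 # a binder/variable; details omitted

class Primitive(NamedTuple):   # assumed shape, for illustration only
  name: str

Atom = Any                     # a Var or a literal, simplified here

class JaxprEqn(NamedTuple):
  primitive: Primitive
  inputs: List[Atom]
  params: Dict[str, Any]
  out_binder: Var              # one output binder per equation, per this diff

class Jaxpr(NamedTuple):
  in_binders: List[Var]
  eqns: List[JaxprEqn]         # assumed field
  outs: List[Atom]             # assumed field

x = Var()
eqn = JaxprEqn(Primitive("sin"), [x], {}, Var())
print(eqn.primitive.name, len(eqn.inputs))
```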
@@ -6,7 +6,7 @@ jupytext:
 extension: .md
 format_name: myst
 format_version: 0.13
-jupytext_version: 1.10.0
+jupytext_version: 1.10.2
 kernelspec:
 display_name: Python 3
 name: python3
@@ -40,10 +40,11 @@ atomic units of processing rather than compositions.
 where we apply primitive functions to numerical inputs to produce numerical
 outputs, we want to override primitive application and let different values
 flow through our program. For example, we might want to replace the
-application of every primitive with type `a -> b` with an application of its
-JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow
-through our program. Moreover, we want to apply a composition of multiple
-transformations, leading to stacks of interpreters.
+application of every primitive with an application of [its JVP
+rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
+and let primal-tangent pairs flow through our program. Moreover, we want to
+apply a composition of multiple transformations, leading to stacks of
+interpreters.

 +++

@@ -124,7 +125,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):

 When we're about to apply a transformed function, we'll push another
 interpreter onto the stack using `new_main`. Then, as we apply primitives in
-the function, we can think of the `bind` first being interprted by the trace
+the function, we can think of the `bind` first being interpreted by the trace
 at the top of the stack (i.e. with the highest level). If that first
 interpreter itself binds other primitives in its interpretation rule for the
 primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,
@@ -247,7 +248,7 @@ class ConcreteArray(ShapedArray):

   @staticmethod
   def _nonzero(tracer):
-    return nonzero(tracer.aval.val)
+    return bool(tracer.aval.val)

 def get_aval(x):
   if isinstance(x, Tracer):
@@ -370,7 +371,7 @@ that now we can add some real transformations.

 ### Forward-mode autodiff with `jvp`

-First, a couple helper functions:
+First, a couple of helper functions:

 ```{code-cell}
 def zeros_like(val):
@@ -445,7 +446,7 @@ jvp_rules[neg_p] = neg_jvp

 def reduce_sum_jvp(primals, tangents, *, axis):
   (x,), (x_dot,) = primals, tangents
-  return reduce_sum(x_dot, axis), reduce_sum(x_dot, axis)
+  return reduce_sum(x, axis), reduce_sum(x_dot, axis)
 jvp_rules[reduce_sum_p] = reduce_sum_jvp

 def greater_jvp(primals, tangents):
@@ -575,7 +576,7 @@ class BatchTrace(Trace):
 vmap_rules = {}
 ```

-Here we've implemented the optional `Tracer.full_lower` method, which lets
+Here we've implemented the optional `Tracer.full_lower` method, which lets us
 peel off a batching tracer if it's not needed because it doesn't represent a
 batched value.

@@ -678,18 +679,20 @@ wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to
 carry a little bit of extra context, but for both `jit` and `vjp` we need
 much richer context: we need to represent _programs_. That is, we need jaxprs!

-We need a program representation for `jit` because the purpose of `jit` is to
-stage computation out of Python. For any computation we want to stage out,
-we need to be able to represent it as data, and build it up as we trace a
-Python function. Similarly, `vjp` needs a way to represent the computation for
-the backward pass of reverse-mode autodiff. We use the same jaxpr program
-representation for both needs.
+Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are
+an explicitly typed, functional, first-order language. We need a program
+representation for `jit` because the purpose of `jit` is to stage computation
+out of Python. For any computation we want to stage out, we need to be able to
+represent it as data, and build it up as we trace a Python function.
+Similarly, `vjp` needs a way to represent the computation for the backward
+pass of reverse-mode autodiff. We use the same jaxpr program representation
+for both needs.

 (Building a program representation is the most
-[free](https://en.wikipedia.org/wiki/Free_object) kind of trace-
-transformation, and so except for issues around handling native Python control
-flow, any transformation could be implemented by first tracing to a jaxpr
-and then interpreting the jaxpr.)
+[free](https://en.wikipedia.org/wiki/Free_object) kind of
+trace- transformation, and so except for issues around handling native Python
+control flow, any transformation could be implemented by first tracing to a
+jaxpr and then interpreting the jaxpr.)

 The jaxpr term syntax is roughly:

@@ -742,7 +745,7 @@ class JaxprEqn(NamedTuple):
   primitive: Primitive
   inputs: List[Atom]
   params: Dict[str, Any]
-  out_binder: List[Var]
+  out_binder: Var

 class Jaxpr(NamedTuple):
   in_binders: List[Var]
@@ -6,7 +6,7 @@
 # extension: .py
 # format_name: light
 # format_version: '1.5'
-# jupytext_version: 1.10.2
+# jupytext_version: 1.10.0
 # kernelspec:
 # display_name: Python 3
 # name: python3
@@ -38,10 +38,11 @@
 # where we apply primitive functions to numerical inputs to produce numerical
 # outputs, we want to override primitive application and let different values
 # flow through our program. For example, we might want to replace the
-# application of every primitive with type `a -> b` with an application of its
-# JVP rule with type `(a, T a) -> (b, T b)`, and let primal-tangent pairs flow
-# through our program. Moreover, we want to apply a composition of multiple
-# transformations, leading to stacks of interpreters.
+# application of every primitive with an application of [its JVP
+# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
+# and let primal-tangent pairs flow through our program. Moreover, we want to
+# apply a composition of multiple transformations, leading to stacks of
+# interpreters.

 # ### JAX core machinery
 #
@@ -125,7 +126,7 @@ def new_main(trace_type: Type['Trace'], global_data=None):

 # When we're about to apply a transformed function, we'll push another
 # interpreter onto the stack using `new_main`. Then, as we apply primitives in
-# the function, we can think of the `bind` first being interprted by the trace
+# the function, we can think of the `bind` first being interpreted by the trace
 # at the top of the stack (i.e. with the highest level). If that first
 # interpreter itself binds other primitives in its interpretation rule for the
 # primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,
@@ -375,7 +376,7 @@ print(f(3.0))

 # ### Forward-mode autodiff with `jvp`
 #
-# First, a couple helper functions:
+# First, a couple of helper functions:

 # +
 def zeros_like(val):
@@ -586,7 +587,7 @@ class BatchTrace(Trace):
 vmap_rules = {}
 # -

-# Here we've implemented the optional `Tracer.full_lower` method, which lets
+# Here we've implemented the optional `Tracer.full_lower` method, which lets us
 # peel off a batching tracer if it's not needed because it doesn't represent a
 # batched value.
 #
@@ -688,18 +689,20 @@ jacfwd(f, np.arange(3.))
 # carry a little bit of extra context, but for both `jit` and `vjp` we need
 # much richer context: we need to represent _programs_. That is, we need jaxprs!
 #
-# We need a program representation for `jit` because the purpose of `jit` is to
-# stage computation out of Python. For any computation we want to stage out,
-# we need to be able to represent it as data, and build it up as we trace a
-# Python function. Similarly, `vjp` needs a way to represent the computation for
-# the backward pass of reverse-mode autodiff. We use the same jaxpr program
-# representation for both needs.
+# Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are
+# an explicitly typed, functional, first-order language. We need a program
+# representation for `jit` because the purpose of `jit` is to stage computation
+# out of Python. For any computation we want to stage out, we need to be able to
+# represent it as data, and build it up as we trace a Python function.
+# Similarly, `vjp` needs a way to represent the computation for the backward
+# pass of reverse-mode autodiff. We use the same jaxpr program representation
+# for both needs.
 #
 # (Building a program representation is the most
-# [free](https://en.wikipedia.org/wiki/Free_object) kind of trace-
-# transformation, and so except for issues around handling native Python control
-# flow, any transformation could be implemented by first tracing to a jaxpr
-# and then interpreting the jaxpr.)
+# [free](https://en.wikipedia.org/wiki/Free_object) kind of
+# trace- transformation, and so except for issues around handling native Python
+# control flow, any transformation could be implemented by first tracing to a
+# jaxpr and then interpreting the jaxpr.)
 #
 # The jaxpr term syntax is roughly:
 #