add linearize code (needs text)
This commit is contained in:
parent 72a3036b1a
commit 3457696e80
autodidax.ipynb (file names inferred from the notebook's own title)

@@ -27,10 +27,16 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"lines_to_next_cell": 0
+"lines_to_next_cell": 2
 },
 "outputs": [],
-"source": []
+"source": [
+"import pdb, sys, traceback\n",
+"def info(type, value, tb):\n",
+"  traceback.print_exception(type, value, tb)\n",
+"  pdb.pm()\n",
+"sys.excepthook = info"
+]
 },
 {
 "cell_type": "markdown",
@@ -322,6 +328,17 @@
 "      raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"def swap(f): return lambda x, y: f(y, x)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -345,9 +362,9 @@
 "\n",
 "  _neg = staticmethod(neg)\n",
 "  _add = staticmethod(add)\n",
-"  _radd = staticmethod(add)\n",
+"  _radd = staticmethod(swap(add))\n",
 "  _mul = staticmethod(mul)\n",
-"  _rmul = staticmethod(mul)\n",
+"  _rmul = staticmethod(swap(mul))\n",
 "  _gt = staticmethod(greater)\n",
 "\n",
 "  @staticmethod\n",
@@ -407,8 +424,22 @@
 "def get_aval(x):\n",
 "  if isinstance(x, Tracer):\n",
 "    return x.aval\n",
+"  elif type(x) in jax_types:\n",
+"    return ConcreteArray(np.asarray(x))\n",
 "  else:\n",
-"    return ConcreteArray(np.asarray(x))"
+"    raise TypeError(x)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"jax_types = {bool, int, float,\n",
+"             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}"
+]
+},
 {
@@ -459,7 +490,7 @@
 },
 "outputs": [],
 "source": [
-"from operator import attrgetter"
+"import operator as op"
 ]
 },
 {
@@ -472,7 +503,7 @@
 "source": [
 "def find_top_trace(xs) -> Trace:\n",
 "  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
-"                 default=trace_stack[0], key=attrgetter('level'))\n",
+"                 default=trace_stack[0], key=op.attrgetter('level'))\n",
 "  if dynamic_trace and dynamic_trace.level > top_main.level:\n",
 "    top_main = dynamic_trace\n",
 "  return top_main.trace_type(top_main)"
@@ -526,6 +557,7 @@
 "source": [
 "def full_raise(trace: Trace, val: Any) -> Tracer:\n",
 "  if not isinstance(val, Tracer):\n",
+"    assert type(val) in jax_types\n",
 "    return trace.pure(val)\n",
 "  level = trace.main.level\n",
 "  if val._trace.main is trace.main:\n",
@@ -1457,7 +1489,10 @@
 "def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
 "  (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
 "  if x_bdim != y_bdim:\n",
-"    y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
+"    if x_bdim is not_mapped:\n",
+"      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)\n",
+"    else:\n",
+"      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
 "  return [op(x, y)], [x_bdim]\n",
 "vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n",
 "vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)"
@@ -1705,6 +1740,9 @@
 "  eqns: List[JaxprEqn]\n",
 "  outs: List[Atom]\n",
 "\n",
+"  def __hash__(self): return id(self)\n",
+"  __eq__ = op.is_\n",
+"\n",
 "def raise_to_shaped(aval):\n",
 "  return ShapedArray(aval.shape, aval.dtype)"
 ]
@@ -2693,6 +2731,7 @@
 },
 "outputs": [],
 "source": [
+"@lru_cache()\n",
 "def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n",
 "  def jvp_traceable(*primals_and_tangents):\n",
 "    n = len(primals_and_tangents) // 2\n",
@@ -2714,13 +2753,14 @@
 "source": [
 "def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
 "  del num_consts  # Unused.\n",
-"  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)\n",
+"  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
 "  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
 "              num_consts=len(new_consts))\n",
 "  return outs, [0] * len(outs)\n",
 "vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
 "\n",
-"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]\n",
+"@lru_cache()\n",
+"def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis]\n",
 "               ) -> Tuple[Jaxpr, List[Any]]:\n",
 "  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
 "  in_avals = [unmapped_aval(axis_size, d, v.aval)\n",
@@ -2749,6 +2789,7 @@
 "source": [
 "@jit\n",
 "def f(x):\n",
+"  print('tracing!')\n",
 "  y = sin(x) * 2.\n",
 "  z = - y + x\n",
 "  return z\n",
@@ -2758,6 +2799,8 @@
 "print(y)\n",
 "print(ydot)\n",
 "\n",
+"y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed\n",
+"\n",
 "ys = vmap(f, (0,))(np.arange(3.))\n",
 "print(ys)"
 ]
@@ -2807,7 +2850,9 @@
 "  _mul = staticmethod(mul)\n",
 "  _rmul = staticmethod(mul)\n",
 "  _gt = staticmethod(greater)\n",
-"input_handlers[DeviceArray] = lambda x: x.buf"
+"input_handlers[DeviceArray] = lambda x: x.buf\n",
+"\n",
+"jax_types.add(DeviceArray)"
 ]
 },
 {
@@ -2827,6 +2872,317 @@
 "print(y)\n",
 "print(ydot)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Part 4: `linearize` and `vjp` (and `grad`!)\n",
+"\n",
+"The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve\n",
+"jaxprs as well. That's because both involve staging out, or delaying,\n",
+"computation.\n",
+"\n",
+"In the case of `linearize`, we want to stage out the linear part of a `jvp`\n",
+"computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
+"then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where\n",
+"```\n",
+"y, f_lin = linearize(f, x)\n",
+"y_dot = f_lin(x_dot)\n",
+"```\n",
+"gives the same result for `(y, y_dot)` as\n",
+"```\n",
+"y, y_dot = jvp(f, (x,), (x_dot,))\n",
+"```\n",
+"and where the application of `f_lin` does not redo any of the linearization\n",
+"work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n",
+"\n",
+"To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:\n",
+"we evaluate all the primal values as we trace, but stage the tangent\n",
+"computations into a jaxpr."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"def split_half(lst):\n",
+"  n, ragged = divmod(len(lst), 2)\n",
+"  assert not ragged\n",
+"  return lst[:n], lst[n:]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"def linearize_flat(f, *primals_in):\n",
+"  pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
+"              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
+"\n",
+"  def f_jvp(*primals_tangents_in):\n",
+"    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
+"    return [*primals_out, *tangents_out]\n",
+"\n",
+"  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)\n",
+"  primal_pvals, _ = split_half(pvals_out)\n",
+"  assert all(pval.is_known for pval in primal_pvals)\n",
+"  primals_out = [pval.const for pval in primal_pvals]\n",
+"  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])\n",
+"  return primals_out, f_lin\n",
+"\n",
+"def linearize(f, *primals_in):\n",
+"  primals_in_flat, in_tree = tree_flatten(primals_in)\n",
+"  f, out_tree = flatten_fun(f, in_tree)\n",
+"  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)\n",
+"  primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
+"\n",
+"  def f_lin(*tangents_in):\n",
+"    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)\n",
+"    if in_tree != in_tree2: raise TypeError\n",
+"    tangents_out_flat = f_lin_flat(*tangents_in_flat)\n",
+"    return tree_unflatten(out_tree(), tangents_out_flat)\n",
+"\n",
+"  return primals_out, f_lin\n",
+"\n",
+"def vspace(aval: ShapedArray) -> ShapedArray:\n",
+"  return raise_to_shaped(aval)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"class PartialVal(NamedTuple):\n",
+"  aval: ShapedArray\n",
+"  const: Optional[Any]\n",
+"\n",
+"  @classmethod\n",
+"  def known(cls, val: Any):\n",
+"    return PartialVal(get_aval(val), val)\n",
+"\n",
+"  @classmethod\n",
+"  def unknown(cls, aval: ShapedArray):\n",
+"    return PartialVal(aval, None)\n",
+"\n",
+"  is_known = property(lambda self: self.const is not None)\n",
+"  is_unknown = property(lambda self: self.const is None)\n",
+"\n",
+"def partial_eval_flat(f, pvals_in: List[PartialVal]):\n",
+"  with new_main(PartialEvalTrace) as main:\n",
+"    trace = PartialEvalTrace(main)\n",
+"    tracers_in = [trace.new_arg(pval) for pval in pvals_in]\n",
+"    outs = f(*tracers_in)\n",
+"    tracers_out = [full_raise(trace, out) for out in outs]\n",
+"    jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)\n",
+"  pvals_out = [t.pval for t in tracers_out]\n",
+"  return jaxpr, pvals_out, consts"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"from weakref import ref, ReferenceType\n",
+"\n",
+"class LambdaBindingRecipe(NamedTuple): pass\n",
+"\n",
+"class ConstRecipe(NamedTuple):\n",
+"  val: Any\n",
+"\n",
+"class JaxprEqnRecipe:\n",
+"  prim: Primitive\n",
+"  tracers_in: List['PartialEvalTracer']\n",
+"  params: Dict[str, Any]\n",
+"  avals_out: List[ShapedArray]\n",
+"  tracer_refs_out: List['ReferenceType[PartialEvalTracer]']\n",
+"\n",
+"  def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):\n",
+"    self.prim = prim\n",
+"    self.tracers_in = tracers_in\n",
+"    self.params = params\n",
+"    self.avals_out = avals_out\n",
+"    self.tracer_refs_out = tracer_refs_out\n",
+"\n",
+"JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]\n",
+"\n",
+"class PartialEvalTracer(Tracer):\n",
+"  pval: PartialVal\n",
+"  recipe: JaxprRecipe\n",
+"\n",
+"  def __init__(self, trace, pval, recipe):\n",
+"    self._trace = trace\n",
+"    self.pval = pval\n",
+"    self.recipe = recipe\n",
+"\n",
+"  @property\n",
+"  def aval(self):\n",
+"    return self.pval.aval\n",
+"\n",
+"  def full_lower(self):\n",
+"    if self.pval.is_known:\n",
+"      return full_lower(self.pval.const)\n",
+"    return self\n",
+"\n",
+"class PartialEvalTrace(Trace):\n",
+"  def new_arg(self, pval: PartialVal) -> Any:\n",
+"    return PartialEvalTracer(self, pval, LambdaBindingRecipe())\n",
+"\n",
+"  def lift(self, val: Any) -> PartialEvalTracer:\n",
+"    return PartialEvalTracer(self, PartialVal.known(val), None)\n",
+"  pure = lift\n",
+"\n",
+"  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:\n",
+"    if tracer.pval.is_unknown:\n",
+"      return tracer\n",
+"    else:\n",
+"      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))\n",
+"      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))\n",
+"\n",
+"  def process_primitive(self, primitive, tracers, params):\n",
+"    if all(t.pval.is_known for t in tracers):\n",
+"      return bind(primitive, *map(full_lower, tracers), **params)\n",
+"    tracers_in = [self.instantiate_const(t) for t in tracers]\n",
+"    avals_in = [t.aval for t in tracers_in]\n",
+"    avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
+"    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)\n",
+"                   for aval in avals_out]\n",
+"    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,\n",
+"                         map(ref, tracers_out))\n",
+"    for t in tracers_out: t.recipe = eqn\n",
+"    return tracers_out"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
+"                     tracers_out: List[PartialEvalTracer]):\n",
+"  tracers_in = [t for t in tracers_in if t.pval.is_unknown]\n",
+"  tracers_out = [t for t in tracers_out if t.pval.is_unknown]\n",
+"\n",
+"  tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}\n",
+"  constvar_to_val = {}\n",
+"  constid_to_var = {}\n",
+"  processed_eqns = set()\n",
+"  eqns = []\n",
+"  for t in toposort(tracers_out, tracer_parents):\n",
+"    if isinstance(t.recipe, LambdaBindingRecipe):\n",
+"      assert id(t) in set(map(id, tracers_in))\n",
+"    elif isinstance(t.recipe, ConstRecipe):\n",
+"      val = t.recipe.val\n",
+"      var = constid_to_var.get(id(val))\n",
+"      if var is None:\n",
+"        aval = raise_to_shaped(get_aval(val))\n",
+"        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)\n",
+"        constvar_to_val[var] = val\n",
+"    elif isinstance(t.recipe, JaxprEqnRecipe):\n",
+"      if id(t.recipe) not in processed_eqns:\n",
+"        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))\n",
+"        processed_eqns.add(id(t.recipe))\n",
+"    else:\n",
+"      raise TypeError(t.recipe)\n",
+"\n",
+"  constvars, constvals = unzip2(constvar_to_val.items())\n",
+"  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]\n",
+"  out_vars = [tracer_to_var[id(t)] for t in tracers_out]\n",
+"  jaxpr = Jaxpr(in_binders, eqns, out_vars)\n",
+"  typecheck_jaxpr(jaxpr)\n",
+"  return jaxpr, constvals\n",
+"\n",
+"def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe\n",
+"                  ) -> JaxprEqn:\n",
+"  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]\n",
+"  out_binders = [Var(aval) for aval in recipe.avals_out]\n",
+"  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):\n",
+"    if t_ref() is not None: tracer_to_var[id(t_ref())] = var\n",
+"  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)\n",
+"\n",
+"def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:\n",
+"  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"lines_to_next_cell": 1
+},
+"outputs": [],
+"source": [
+"def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
+"  if not out_nodes: return []\n",
+"  out_nodes = remove_duplicates(out_nodes)\n",
+"\n",
+"  child_counts = {}\n",
+"  stack = list(out_nodes)\n",
+"  while stack:\n",
+"    node = stack.pop()\n",
+"    if id(node) in child_counts:\n",
+"      child_counts[id(node)] += 1\n",
+"    else:\n",
+"      child_counts[id(node)] = 1\n",
+"      stack.extend(parents(node))\n",
+"  for node in out_nodes:\n",
+"    child_counts[id(node)] -= 1\n",
+"\n",
+"  sorted_nodes = []\n",
+"  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]\n",
+"  while childless_nodes:\n",
+"    node = childless_nodes.pop()\n",
+"    sorted_nodes.append(node)\n",
+"    for parent in parents(node):\n",
+"      if child_counts[id(parent)] == 1:\n",
+"        childless_nodes.append(parent)\n",
+"      else:\n",
+"        child_counts[id(parent)] -= 1\n",
+"\n",
+"  sorted_nodes = sorted_nodes[::-1]\n",
+"  check_toposort(sorted_nodes, parents)\n",
+"  return sorted_nodes\n",
+"\n",
+"def remove_duplicates(lst):\n",
+"  seen = set()\n",
+"  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]\n",
+"\n",
+"def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
+"  seen = set()\n",
+"  for node in nodes:\n",
+"    assert all(id(parent) in seen for parent in parents(node))\n",
+"    seen.add(id(node))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"y, sin_lin = linearize(sin, 3.)\n",
+"print(y, sin(3.))\n",
+"print(sin_lin(1.), cos(3.))"
+]
+}
 ],
 "metadata": {

autodidax.md

@@ -33,7 +33,11 @@ limitations under the License.
 ```

+```{code-cell} ipython3
+import pdb, sys, traceback
+def info(type, value, tb):
+  traceback.print_exception(type, value, tb)
+  pdb.pm()
+sys.excepthook = info
+```
+
 # Autodidax: JAX core from scratch
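Editor's note: the cell added in the hunk above installs a debugging hook. Any uncaught exception now prints its traceback and then drops into `pdb.pm()` for post-mortem inspection, which is convenient while poking at the tracing machinery this notebook builds up.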
@@ -241,6 +245,10 @@ class Tracer:
       raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
 ```

+```{code-cell} ipython3
+def swap(f): return lambda x, y: f(y, x)
+```
+
 ```{code-cell} ipython3
 class ShapedArray:
   array_abstraction_level = 1
@@ -257,9 +265,9 @@ class ShapedArray:

   _neg = staticmethod(neg)
   _add = staticmethod(add)
-  _radd = staticmethod(add)
+  _radd = staticmethod(swap(add))
   _mul = staticmethod(mul)
-  _rmul = staticmethod(mul)
+  _rmul = staticmethod(swap(mul))
   _gt = staticmethod(greater)

   @staticmethod
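Editor's sketch (not part of the commit): why `swap` is the right fix here. Python invokes the reflected hooks `__radd__`/`__rmul__` with `self` on the right-hand side of the original expression, so the primitive should receive its operands swapped back into source order. `add` and `mul` happen to be commutative, but the swapped form keeps the traced operand order honest, and it is what a non-commutative reflected op would need:

```python
def swap(f): return lambda x, y: f(y, x)

sub = lambda x, y: x - y   # a stand-in non-commutative binop (hypothetical)
rsub = swap(sub)           # what a __rsub__-style hook should call
print(sub(5., 2.))         # 3.0
print(rsub(5., 2.))        # -3.0, i.e. sub(2., 5.)
```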
@@ -304,8 +312,15 @@ class ConcreteArray(ShapedArray):
 def get_aval(x):
   if isinstance(x, Tracer):
     return x.aval
+  elif type(x) in jax_types:
+    return ConcreteArray(np.asarray(x))
   else:
-    return ConcreteArray(np.asarray(x))
+    raise TypeError(x)
 ```

+```{code-cell} ipython3
+jax_types = {bool, int, float,
+             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
+```
+
 Notice that we actually have two `AbstractValue`s for arrays, representing
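Editor's sketch of what the `get_aval` change buys, assuming the definitions above: values whose types are not registered in `jax_types` now fail fast with a `TypeError` instead of being silently wrapped in a `ConcreteArray`:

```python
print(type(get_aval(3.0)).__name__)   # ConcreteArray: float is in jax_types
try:
  get_aval('hello')                   # strings are not valid jax types
except TypeError as e:
  print('rejected:', e)
```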
@@ -332,13 +347,13 @@ top trace's `Tracer` instances, and the call to `full_lower` is an optional
 optimization so that we unbox values out of `Tracer`s as much as possible.

 ```{code-cell} ipython3
-from operator import attrgetter
+import operator as op
 ```

 ```{code-cell} ipython3
 def find_top_trace(xs) -> Trace:
   top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
-                 default=trace_stack[0], key=attrgetter('level'))
+                 default=trace_stack[0], key=op.attrgetter('level'))
   if dynamic_trace and dynamic_trace.level > top_main.level:
     top_main = dynamic_trace
   return top_main.trace_type(top_main)
@@ -373,6 +388,7 @@ def full_lower(val: Any):
 ```{code-cell} ipython3
 def full_raise(trace: Trace, val: Any) -> Tracer:
   if not isinstance(val, Tracer):
+    assert type(val) in jax_types
     return trace.pure(val)
   level = trace.main.level
   if val._trace.main is trace.main:
@@ -875,7 +891,10 @@ from functools import partial
 def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
   (x, y), (x_bdim, y_bdim) = vals_in, dims_in
   if x_bdim != y_bdim:
-    y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
+    if x_bdim is not_mapped:
+      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
+    else:
+      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
   return [op(x, y)], [x_bdim]
 vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
 vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
@@ -1058,6 +1077,9 @@ class Jaxpr(NamedTuple):
   eqns: List[JaxprEqn]
   outs: List[Atom]

+  def __hash__(self): return id(self)
+  __eq__ = op.is_
+
 def raise_to_shaped(aval):
   return ShapedArray(aval.shape, aval.dtype)
 ```
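Editor's sketch of why `Jaxpr` now gets identity-based `__hash__` and `__eq__`: a `NamedTuple` is hashed structurally by default, which fails outright once a field holds an unhashable list, and identity semantics are what let jaxprs serve as cache keys for the `lru_cache`-decorated `jvp_jaxpr` and `vmap_jaxpr` below. In isolation:

```python
import operator as op
from typing import List, NamedTuple

class Thing(NamedTuple):
  eqns: List[int]                     # a list field breaks the default structural hash

  def __hash__(self): return id(self)
  __eq__ = op.is_

t = Thing([1, 2, 3])
print(hash(t) == id(t), t == t)       # True True: usable as an lru_cache key now
```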
@@ -1743,6 +1765,7 @@ jvp_rules[xla_call_p] = xla_call_jvp_rule
 ```

 ```{code-cell} ipython3
+@lru_cache()
 def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
   def jvp_traceable(*primals_and_tangents):
     n = len(primals_and_tangents) // 2
@@ -1757,13 +1780,14 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
 ```{code-cell} ipython3
 def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
   del num_consts  # Unused.
-  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)
+  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
   outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
               num_consts=len(new_consts))
   return outs, [0] * len(outs)
 vmap_rules[xla_call_p] = xla_call_vmap_rule

-def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]
+@lru_cache()
+def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis]
                ) -> Tuple[Jaxpr, List[Any]]:
   vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
   in_avals = [unmapped_aval(axis_size, d, v.aval)
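Editor's note on the `tuple(dims_in)` and `List[BatchAxis]` to `Tuple[BatchAxis]` changes: they go hand in hand with the new `@lru_cache()`, which requires hashable arguments, and lists are not hashable. A minimal standalone illustration:

```python
from functools import lru_cache

@lru_cache()
def total(xs):
  return sum(xs)

print(total((1, 2, 3)))   # fine: tuples are hashable cache keys
try:
  total([1, 2, 3])
except TypeError as e:
  print(e)                # unhashable type: 'list'
```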
@@ -1784,6 +1808,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
 ```{code-cell} ipython3
 @jit
 def f(x):
+  print('tracing!')
   y = sin(x) * 2.
   z = - y + x
   return z
@@ -1793,6 +1818,8 @@ y, ydot = jvp(f, (x,), (xdot,))
 print(y)
 print(ydot)

+y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
+
 ys = vmap(f, (0,))(np.arange(3.))
 print(ys)
 ```
@@ -1831,6 +1858,8 @@ class DeviceArray:
   _rmul = staticmethod(mul)
   _gt = staticmethod(greater)
 input_handlers[DeviceArray] = lambda x: x.buf
+
+jax_types.add(DeviceArray)
 ```

 ```{code-cell} ipython3
@@ -1845,3 +1874,262 @@ y, ydot = jvp(f, (x,), (xdot,))
 print(y)
 print(ydot)
 ```
+
+## Part 4: `linearize` and `vjp` (and `grad`!)
+
+The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
+jaxprs as well. That's because both involve staging out, or delaying,
+computation.
+
+In the case of `linearize`, we want to stage out the linear part of a `jvp`
+computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
+then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where
+```
+y, f_lin = linearize(f, x)
+y_dot = f_lin(x_dot)
+```
+gives the same result for `(y, y_dot)` as
+```
+y, y_dot = jvp(f, (x,), (x_dot,))
+```
+and where the application of `f_lin` does not redo any of the linearization
+work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
+
+To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
+we evaluate all the primal values as we trace, but stage the tangent
+computations into a jaxpr.
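Editor's sketch: the contract stated in the prose above, made runnable with the `linearize` and `jvp` defined in this file (it mirrors the check at the end of the section):

```python
f = sin
x, x_dot = 3., 1.

y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)

y_ref, y_dot_ref = jvp(f, (x,), (x_dot,))
print(y, y_ref)          # same primal output
print(y_dot, y_dot_ref)  # same tangent output, without re-linearizing
```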
+```{code-cell} ipython3
+def split_half(lst):
+  n, ragged = divmod(len(lst), 2)
+  assert not ragged
+  return lst[:n], lst[n:]
+```
+
+```{code-cell} ipython3
+def linearize_flat(f, *primals_in):
+  pvals_in = ([PartialVal.known(x) for x in primals_in] +
+              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
+
+  def f_jvp(*primals_tangents_in):
+    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
+    return [*primals_out, *tangents_out]
+
+  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
+  primal_pvals, _ = split_half(pvals_out)
+  assert all(pval.is_known for pval in primal_pvals)
+  primals_out = [pval.const for pval in primal_pvals]
+  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
+  return primals_out, f_lin
+
+def linearize(f, *primals_in):
+  primals_in_flat, in_tree = tree_flatten(primals_in)
+  f, out_tree = flatten_fun(f, in_tree)
+  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
+  primals_out = tree_unflatten(out_tree(), primals_out_flat)
+
+  def f_lin(*tangents_in):
+    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
+    if in_tree != in_tree2: raise TypeError
+    tangents_out_flat = f_lin_flat(*tangents_in_flat)
+    return tree_unflatten(out_tree(), tangents_out_flat)
+
+  return primals_out, f_lin
+
+def vspace(aval: ShapedArray) -> ShapedArray:
+  return raise_to_shaped(aval)
+```
+
+```{code-cell} ipython3
+class PartialVal(NamedTuple):
+  aval: ShapedArray
+  const: Optional[Any]
+
+  @classmethod
+  def known(cls, val: Any):
+    return PartialVal(get_aval(val), val)
+
+  @classmethod
+  def unknown(cls, aval: ShapedArray):
+    return PartialVal(aval, None)
+
+  is_known = property(lambda self: self.const is not None)
+  is_unknown = property(lambda self: self.const is None)
+
+def partial_eval_flat(f, pvals_in: List[PartialVal]):
+  with new_main(PartialEvalTrace) as main:
+    trace = PartialEvalTrace(main)
+    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
+    outs = f(*tracers_in)
+    tracers_out = [full_raise(trace, out) for out in outs]
+    jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)
+  pvals_out = [t.pval for t in tracers_out]
+  return jaxpr, pvals_out, consts
+```
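Editor's note: this is the hinge of the whole scheme. `linearize_flat` marks primal inputs as known and tangent inputs as unknown, so `partial_eval_flat` evaluates the primal half of the JVP eagerly and stages only the tangent half into a jaxpr. A tiny illustration of the two flavors, assuming the definitions above:

```python
pv_known = PartialVal.known(3.0)                        # primal input: evaluate now
pv_unknown = PartialVal.unknown(vspace(get_aval(3.0)))  # tangent input: stage out
print(pv_known.is_known, pv_unknown.is_unknown)         # True True
```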
+```{code-cell} ipython3
+from weakref import ref, ReferenceType
+
+class LambdaBindingRecipe(NamedTuple): pass
+
+class ConstRecipe(NamedTuple):
+  val: Any
+
+class JaxprEqnRecipe:
+  prim: Primitive
+  tracers_in: List['PartialEvalTracer']
+  params: Dict[str, Any]
+  avals_out: List[ShapedArray]
+  tracer_refs_out: List['ReferenceType[PartialEvalTracer]']
+
+  def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):
+    self.prim = prim
+    self.tracers_in = tracers_in
+    self.params = params
+    self.avals_out = avals_out
+    self.tracer_refs_out = tracer_refs_out
+
+JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
+
+class PartialEvalTracer(Tracer):
+  pval: PartialVal
+  recipe: JaxprRecipe
+
+  def __init__(self, trace, pval, recipe):
+    self._trace = trace
+    self.pval = pval
+    self.recipe = recipe
+
+  @property
+  def aval(self):
+    return self.pval.aval
+
+  def full_lower(self):
+    if self.pval.is_known:
+      return full_lower(self.pval.const)
+    return self
+
+class PartialEvalTrace(Trace):
+  def new_arg(self, pval: PartialVal) -> Any:
+    return PartialEvalTracer(self, pval, LambdaBindingRecipe())
+
+  def lift(self, val: Any) -> PartialEvalTracer:
+    return PartialEvalTracer(self, PartialVal.known(val), None)
+  pure = lift
+
+  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
+    if tracer.pval.is_unknown:
+      return tracer
+    else:
+      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
+      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
+
+  def process_primitive(self, primitive, tracers, params):
+    if all(t.pval.is_known for t in tracers):
+      return bind(primitive, *map(full_lower, tracers), **params)
+    tracers_in = [self.instantiate_const(t) for t in tracers]
+    avals_in = [t.aval for t in tracers_in]
+    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
+    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
+                   for aval in avals_out]
+    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
+                         map(ref, tracers_out))
+    for t in tracers_out: t.recipe = eqn
+    return tracers_out
+```
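Editor's note on `tracer_refs_out` holding weak references: an output tracer the user's program has dropped can be garbage-collected, and `recipe_to_eqn` in the next cell simply skips binding a variable for it. The `weakref` mechanics in isolation (the behavior shown relies on CPython's immediate reclamation):

```python
from weakref import ref

class Obj: pass

o = Obj()
r = ref(o)
print(r() is o)   # True while the referent is alive
del o
print(r())        # None: a dead output drops out without keeping memory alive
```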
+```{code-cell} ipython3
+def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
+                     tracers_out: List[PartialEvalTracer]):
+  tracers_in = [t for t in tracers_in if t.pval.is_unknown]
+  tracers_out = [t for t in tracers_out if t.pval.is_unknown]
+
+  tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
+  constvar_to_val = {}
+  constid_to_var = {}
+  processed_eqns = set()
+  eqns = []
+  for t in toposort(tracers_out, tracer_parents):
+    if isinstance(t.recipe, LambdaBindingRecipe):
+      assert id(t) in set(map(id, tracers_in))
+    elif isinstance(t.recipe, ConstRecipe):
+      val = t.recipe.val
+      var = constid_to_var.get(id(val))
+      if var is None:
+        aval = raise_to_shaped(get_aval(val))
+        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)
+        constvar_to_val[var] = val
+    elif isinstance(t.recipe, JaxprEqnRecipe):
+      if id(t.recipe) not in processed_eqns:
+        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
+        processed_eqns.add(id(t.recipe))
+    else:
+      raise TypeError(t.recipe)
+
+  constvars, constvals = unzip2(constvar_to_val.items())
+  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
+  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
+  jaxpr = Jaxpr(in_binders, eqns, out_vars)
+  typecheck_jaxpr(jaxpr)
+  return jaxpr, constvals
+
+def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
+                  ) -> JaxprEqn:
+  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
+  out_binders = [Var(aval) for aval in recipe.avals_out]
+  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
+    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
+  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
+
+def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
+  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
+```
+```{code-cell} ipython3
+def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
+  if not out_nodes: return []
+  out_nodes = remove_duplicates(out_nodes)
+
+  child_counts = {}
+  stack = list(out_nodes)
+  while stack:
+    node = stack.pop()
+    if id(node) in child_counts:
+      child_counts[id(node)] += 1
+    else:
+      child_counts[id(node)] = 1
+      stack.extend(parents(node))
+  for node in out_nodes:
+    child_counts[id(node)] -= 1
+
+  sorted_nodes = []
+  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
+  while childless_nodes:
+    node = childless_nodes.pop()
+    sorted_nodes.append(node)
+    for parent in parents(node):
+      if child_counts[id(parent)] == 1:
+        childless_nodes.append(parent)
+      else:
+        child_counts[id(parent)] -= 1
+
+  sorted_nodes = sorted_nodes[::-1]
+  check_toposort(sorted_nodes, parents)
+  return sorted_nodes
+
+def remove_duplicates(lst):
+  seen = set()
+  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
+
+def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
+  seen = set()
+  for node in nodes:
+    assert all(id(parent) in seen for parent in parents(node))
+    seen.add(id(node))
+```
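Editor's sketch: a standalone use of `toposort`. Nodes are keyed by `id`, and `parents` maps a node to its inputs, so every parent lands before its children:

```python
class Node:
  def __init__(self, name, parents):
    self.name, self.parents = name, parents

a = Node('a', []); b = Node('b', [a]); c = Node('c', [a, b])
order = toposort([c], lambda n: n.parents)
print([n.name for n in order])   # ['a', 'b', 'c']
```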
+```{code-cell} ipython3
+y, sin_lin = linearize(sin, 3.)
+print(y, sin(3.))
+print(sin_lin(1.), cos(3.))
+```

autodidax.py

@@ -26,6 +26,12 @@
 # name: python3
 # ---

+import pdb, sys, traceback
+def info(type, value, tb):
+  traceback.print_exception(type, value, tb)
+  pdb.pm()
+sys.excepthook = info
+

 # # Autodidax: JAX core from scratch
 #
@@ -215,6 +221,8 @@ class Tracer:
     except AttributeError:
       raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

+def swap(f): return lambda x, y: f(y, x)
+
 class ShapedArray:
   array_abstraction_level = 1
   shape: Tuple[int]
@@ -230,9 +238,9 @@ class ShapedArray:

   _neg = staticmethod(neg)
   _add = staticmethod(add)
-  _radd = staticmethod(add)
+  _radd = staticmethod(swap(add))
   _mul = staticmethod(mul)
-  _rmul = staticmethod(mul)
+  _rmul = staticmethod(swap(mul))
   _gt = staticmethod(greater)

   @staticmethod
@@ -273,8 +281,13 @@ class ConcreteArray(ShapedArray):
 def get_aval(x):
   if isinstance(x, Tracer):
     return x.aval
+  elif type(x) in jax_types:
+    return ConcreteArray(np.asarray(x))
   else:
-    return ConcreteArray(np.asarray(x))
+    raise TypeError(x)

+jax_types = {bool, int, float,
+             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
+
 # Notice that we actually have two `AbstractValue`s for arrays, representing
 # different levels of abstraction. A `ShapedArray` represents the set of all
@@ -297,11 +310,11 @@ def bind(prim, *args, **params):
 # top trace's `Tracer` instances, and the call to `full_lower` is an optional
 # optimization so that we unbox values out of `Tracer`s as much as possible.

-from operator import attrgetter
+import operator as op

 def find_top_trace(xs) -> Trace:
   top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
-                 default=trace_stack[0], key=attrgetter('level'))
+                 default=trace_stack[0], key=op.attrgetter('level'))
   if dynamic_trace and dynamic_trace.level > top_main.level:
     top_main = dynamic_trace
   return top_main.trace_type(top_main)
@@ -332,6 +345,7 @@ def full_lower(val: Any):

 def full_raise(trace: Trace, val: Any) -> Tracer:
   if not isinstance(val, Tracer):
+    assert type(val) in jax_types
     return trace.pure(val)
   level = trace.main.level
   if val._trace.main is trace.main:
@@ -749,7 +763,10 @@ from functools import partial
 def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):
   (x, y), (x_bdim, y_bdim) = vals_in, dims_in
   if x_bdim != y_bdim:
-    y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
+    if x_bdim is not_mapped:
+      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
+    else:
+      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
   return [op(x, y)], [x_bdim]
 vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)
 vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)
@@ -919,6 +936,9 @@ class Jaxpr(NamedTuple):
   eqns: List[JaxprEqn]
   outs: List[Atom]

+  def __hash__(self): return id(self)
+  __eq__ = op.is_
+
 def raise_to_shaped(aval):
   return ShapedArray(aval.shape, aval.dtype)
 # -
@@ -1551,6 +1571,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
   return primals_out, tangents_out
 jvp_rules[xla_call_p] = xla_call_jvp_rule

+@lru_cache()
 def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
   def jvp_traceable(*primals_and_tangents):
     n = len(primals_and_tangents) // 2
@@ -1564,13 +1585,14 @@ def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:
 # +
 def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
   del num_consts  # Unused.
-  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)
+  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
   outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
               num_consts=len(new_consts))
   return outs, [0] * len(outs)
 vmap_rules[xla_call_p] = xla_call_vmap_rule

-def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]
+@lru_cache()
+def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis]
                ) -> Tuple[Jaxpr, List[Any]]:
   vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
   in_avals = [unmapped_aval(axis_size, d, v.aval)
@@ -1591,6 +1613,7 @@ def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
 # +
 @jit
 def f(x):
+  print('tracing!')
   y = sin(x) * 2.
   z = - y + x
   return z
@@ -1600,6 +1623,8 @@ y, ydot = jvp(f, (x,), (xdot,))
 print(y)
 print(ydot)

+y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
+
 ys = vmap(f, (0,))(np.arange(3.))
 print(ys)
 # -
@@ -1640,6 +1665,8 @@ class DeviceArray:
   _gt = staticmethod(greater)
 input_handlers[DeviceArray] = lambda x: x.buf

+jax_types.add(DeviceArray)
+
 # +
 @jit
 def f(x):
@@ -1651,3 +1678,259 @@ x, xdot = 3., 1.
 y, ydot = jvp(f, (x,), (xdot,))
 print(y)
 print(ydot)
+
+
+# -
+
+# ## Part 4: `linearize` and `vjp` (and `grad`!)
+#
+# The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve
+# jaxprs as well. That's because both involve staging out, or delaying,
+# computation.
+#
+# In the case of `linearize`, we want to stage out the linear part of a `jvp`
+# computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
+# then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, where
+# ```
+# y, f_lin = linearize(f, x)
+# y_dot = f_lin(x_dot)
+# ```
+# gives the same result for `(y, y_dot)` as
+# ```
+# y, y_dot = jvp(f, (x,), (x_dot,))
+# ```
+# and where the application of `f_lin` does not redo any of the linearization
+# work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
+#
+# To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:
+# we evaluate all the primal values as we trace, but stage the tangent
+# computations into a jaxpr.

+def split_half(lst):
+  n, ragged = divmod(len(lst), 2)
+  assert not ragged
+  return lst[:n], lst[n:]
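Editor's note: `split_half` separates the packed `[*primals, *tangents]` convention that `f_jvp` in the next hunk relies on; for example `split_half([1, 2, 3, 4])` gives `([1, 2], [3, 4])`, and an odd-length input trips the assertion.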
+# +
+def linearize_flat(f, *primals_in):
+  pvals_in = ([PartialVal.known(x) for x in primals_in] +
+              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
+
+  def f_jvp(*primals_tangents_in):
+    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
+    return [*primals_out, *tangents_out]
+
+  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
+  primal_pvals, _ = split_half(pvals_out)
+  assert all(pval.is_known for pval in primal_pvals)
+  primals_out = [pval.const for pval in primal_pvals]
+  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
+  return primals_out, f_lin
+
+def linearize(f, *primals_in):
+  primals_in_flat, in_tree = tree_flatten(primals_in)
+  f, out_tree = flatten_fun(f, in_tree)
+  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
+  primals_out = tree_unflatten(out_tree(), primals_out_flat)
+
+  def f_lin(*tangents_in):
+    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
+    if in_tree != in_tree2: raise TypeError
+    tangents_out_flat = f_lin_flat(*tangents_in_flat)
+    return tree_unflatten(out_tree(), tangents_out_flat)
+
+  return primals_out, f_lin
+
+def vspace(aval: ShapedArray) -> ShapedArray:
+  return raise_to_shaped(aval)
+
+# +
+class PartialVal(NamedTuple):
+  aval: ShapedArray
+  const: Optional[Any]
+
+  @classmethod
+  def known(cls, val: Any):
+    return PartialVal(get_aval(val), val)
+
+  @classmethod
+  def unknown(cls, aval: ShapedArray):
+    return PartialVal(aval, None)
+
+  is_known = property(lambda self: self.const is not None)
+  is_unknown = property(lambda self: self.const is None)
+
+def partial_eval_flat(f, pvals_in: List[PartialVal]):
+  with new_main(PartialEvalTrace) as main:
+    trace = PartialEvalTrace(main)
+    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
+    outs = f(*tracers_in)
+    tracers_out = [full_raise(trace, out) for out in outs]
+    jaxpr, consts = tracers_to_jaxpr(tracers_in, tracers_out)
+  pvals_out = [t.pval for t in tracers_out]
+  return jaxpr, pvals_out, consts
+
+# +
+from weakref import ref, ReferenceType
+
+class LambdaBindingRecipe(NamedTuple): pass
+
+class ConstRecipe(NamedTuple):
+  val: Any
+
+class JaxprEqnRecipe:
+  prim: Primitive
+  tracers_in: List['PartialEvalTracer']
+  params: Dict[str, Any]
+  avals_out: List[ShapedArray]
+  tracer_refs_out: List['ReferenceType[PartialEvalTracer]']
+
+  def __init__(self, prim, tracers_in, params, avals_out, tracer_refs_out):
+    self.prim = prim
+    self.tracers_in = tracers_in
+    self.params = params
+    self.avals_out = avals_out
+    self.tracer_refs_out = tracer_refs_out
+
+JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
+
+class PartialEvalTracer(Tracer):
+  pval: PartialVal
+  recipe: JaxprRecipe
+
+  def __init__(self, trace, pval, recipe):
+    self._trace = trace
+    self.pval = pval
+    self.recipe = recipe
+
+  @property
+  def aval(self):
+    return self.pval.aval
+
+  def full_lower(self):
+    if self.pval.is_known:
+      return full_lower(self.pval.const)
+    return self
+
+class PartialEvalTrace(Trace):
+  def new_arg(self, pval: PartialVal) -> Any:
+    return PartialEvalTracer(self, pval, LambdaBindingRecipe())
+
+  def lift(self, val: Any) -> PartialEvalTracer:
+    return PartialEvalTracer(self, PartialVal.known(val), None)
+  pure = lift
+
+  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
+    if tracer.pval.is_unknown:
+      return tracer
+    else:
+      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
+      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
+
+  def process_primitive(self, primitive, tracers, params):
+    if all(t.pval.is_known for t in tracers):
+      return bind(primitive, *map(full_lower, tracers), **params)
+    tracers_in = [self.instantiate_const(t) for t in tracers]
+    avals_in = [t.aval for t in tracers_in]
+    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
+    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
+                   for aval in avals_out]
+    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
+                         map(ref, tracers_out))
+    for t in tracers_out: t.recipe = eqn
+    return tracers_out
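Editor's note on `process_primitive` above: when every input is known it does not stage anything; it re-binds the primitive on the constants immediately. That fast path is exactly how the primal half of a JVP gets fully evaluated during `linearize`. A sketch assuming the machinery in this file (and autodidax's `x*y_dot + x_dot*y` JVP rule for `mul`):

```python
y, f_lin = linearize(lambda x: x * x, 2.)
print(y)          # 4.0, computed eagerly via the all-known fast path
print(f_lin(1.))  # 4.0 == 2*x*x_dot: the staged tangent computation, replayed
```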
+# +
+def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
+                     tracers_out: List[PartialEvalTracer]):
+  tracers_in = [t for t in tracers_in if t.pval.is_unknown]
+  tracers_out = [t for t in tracers_out if t.pval.is_unknown]
+
+  tracer_to_var = {id(t): Var(raise_to_shaped(t.aval)) for t in tracers_in}
+  constvar_to_val = {}
+  constid_to_var = {}
+  processed_eqns = set()
+  eqns = []
+  for t in toposort(tracers_out, tracer_parents):
+    if isinstance(t.recipe, LambdaBindingRecipe):
+      assert id(t) in set(map(id, tracers_in))
+    elif isinstance(t.recipe, ConstRecipe):
+      val = t.recipe.val
+      var = constid_to_var.get(id(val))
+      if var is None:
+        aval = raise_to_shaped(get_aval(val))
+        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)
+        constvar_to_val[var] = val
+    elif isinstance(t.recipe, JaxprEqnRecipe):
+      if id(t.recipe) not in processed_eqns:
+        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
+        processed_eqns.add(id(t.recipe))
+    else:
+      raise TypeError(t.recipe)
+
+  constvars, constvals = unzip2(constvar_to_val.items())
+  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
+  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
+  jaxpr = Jaxpr(in_binders, eqns, out_vars)
+  typecheck_jaxpr(jaxpr)
+  return jaxpr, constvals
+
+def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe
+                  ) -> JaxprEqn:
+  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
+  out_binders = [Var(aval) for aval in recipe.avals_out]
+  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
+    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
+  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
+
+def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:
+  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
+
+# +
+def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
+  if not out_nodes: return []
+  out_nodes = remove_duplicates(out_nodes)
+
+  child_counts = {}
+  stack = list(out_nodes)
+  while stack:
+    node = stack.pop()
+    if id(node) in child_counts:
+      child_counts[id(node)] += 1
+    else:
+      child_counts[id(node)] = 1
+      stack.extend(parents(node))
+  for node in out_nodes:
+    child_counts[id(node)] -= 1
+
+  sorted_nodes = []
+  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
+  while childless_nodes:
+    node = childless_nodes.pop()
+    sorted_nodes.append(node)
+    for parent in parents(node):
+      if child_counts[id(parent)] == 1:
+        childless_nodes.append(parent)
+      else:
+        child_counts[id(parent)] -= 1
+
+  sorted_nodes = sorted_nodes[::-1]
+  check_toposort(sorted_nodes, parents)
+  return sorted_nodes
+
+def remove_duplicates(lst):
+  seen = set()
+  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
+
+def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
+  seen = set()
+  for node in nodes:
+    assert all(id(parent) in seen for parent in parents(node))
+    seen.add(id(node))
+
+
+# -
+
+y, sin_lin = linearize(sin, 3.)
+print(y, sin(3.))
+print(sin_lin(1.), cos(3.))