Merge pull request #10668 from sharadmv:custom-interpreter-update

PiperOrigin-RevId: 448058058
This commit is contained in:
jax authors 2022-05-11 12:31:13 -07:00
commit d092d6305f
2 changed files with 22 additions and 28 deletions

View File

@ -102,9 +102,8 @@
"source": [
"To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a \"pretty-printing\" transformation:\n",
"it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.\n",
"Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs\n",
"are structured."
"`make_jaxpr` is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs are structured."
]
},
{
@ -201,7 +200,7 @@
"\n",
"### 1. Tracing a function\n",
"\n",
"We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`."
"Let's use `make_jaxpr` to trace a function into a Jaxpr."
]
},
{
@ -227,8 +226,8 @@
"id": "CpTml2PTrzZ4"
},
"source": [
"This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr\n",
"from a list of partial value inputs."
"`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with\n",
"the constants (`literals`) from the trace."
]
},
{
@ -243,7 +242,7 @@
" return jnp.exp(jnp.tanh(x))\n",
"\n",
"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))\n",
"print(closed_jaxpr)\n",
"print(closed_jaxpr.jaxpr)\n",
"print(closed_jaxpr.literals)"
]
},
@ -321,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
"Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
@ -389,9 +388,8 @@
"def inverse(fun):\n",
" @wraps(fun)\n",
" def wrapped(*args, **kwargs):\n",
" # Since we assume unary functions, we won't\n",
" # worry about flattening and\n",
" # unflattening arguments\n",
" # Since we assume unary functions, we won't worry about flattening and\n",
" # unflattening arguments.\n",
" closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)\n",
" out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)\n",
" return out[0]\n",
@ -434,9 +432,8 @@
" # outvars are now invars \n",
" invals = safe_map(read, eqn.outvars)\n",
" if eqn.primitive not in inverse_registry:\n",
" raise NotImplementedError(\"{} does not have registered inverse.\".format(\n",
" eqn.primitive\n",
" ))\n",
" raise NotImplementedError(\n",
" f\"{eqn.primitive} does not have registered inverse.\")\n",
" # Assuming a unary function \n",
" outval = inverse_registry[eqn.primitive](*invals)\n",
" safe_map(write, eqn.invars, [outval])\n",

View File

@ -70,9 +70,8 @@ for function transformation.
To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a "pretty-printing" transformation:
it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.
Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs
are structured.
`make_jaxpr` is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs are structured.
```{code-cell} ipython3
:id: RSxEiWi-EeYW
@ -139,7 +138,7 @@ The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interp
### 1. Tracing a function
We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`.
Let's use `make_jaxpr` to trace a function into a Jaxpr.
```{code-cell} ipython3
:id: BHkg_3P1pXJj
@ -155,8 +154,8 @@ from jax._src.util import safe_map
+++ {"id": "CpTml2PTrzZ4"}
This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr
from a list of partial value inputs.
`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with
the constants (`literals`) from the trace.
```{code-cell} ipython3
:id: Tc1REN5aq_fH
@ -165,7 +164,7 @@ def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr)
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
```
@ -224,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+++ {"id": "0vb2ZoGrCMM4"}
@ -261,9 +260,8 @@ inverse_registry[lax.tanh_p] = jnp.arctanh
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# Since we assume unary functions, we won't
# worry about flattening and
# unflattening arguments
# Since we assume unary functions, we won't worry about flattening and
# unflattening arguments.
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
@ -296,9 +294,8 @@ def inverse_jaxpr(jaxpr, consts, *args):
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError("{} does not have registered inverse.".format(
eqn.primitive
))
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])