mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10668 from sharadmv:custom-interpreter-update
PiperOrigin-RevId: 448058058
This commit is contained in:
commit
d092d6305f
@ -102,9 +102,8 @@
|
||||
"source": [
|
||||
"To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a \"pretty-printing\" transformation:\n",
|
||||
"it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.\n",
|
||||
"Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.\n",
|
||||
"Let's use it to look at how some example Jaxprs\n",
|
||||
"are structured."
|
||||
"`make_jaxpr` is useful for debugging and introspection.\n",
|
||||
"Let's use it to look at how some example Jaxprs are structured."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -201,7 +200,7 @@
|
||||
"\n",
|
||||
"### 1. Tracing a function\n",
|
||||
"\n",
|
||||
"We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`."
|
||||
"Let's use `make_jaxpr` to trace a function into a Jaxpr."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -227,8 +226,8 @@
|
||||
"id": "CpTml2PTrzZ4"
|
||||
},
|
||||
"source": [
|
||||
"This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr\n",
|
||||
"from a list of partial value inputs."
|
||||
"`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with\n",
|
||||
"the constants (`literals`) from the trace."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -243,7 +242,7 @@
|
||||
" return jnp.exp(jnp.tanh(x))\n",
|
||||
"\n",
|
||||
"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))\n",
|
||||
"print(closed_jaxpr)\n",
|
||||
"print(closed_jaxpr.jaxpr)\n",
|
||||
"print(closed_jaxpr.literals)"
|
||||
]
|
||||
},
|
||||
@ -321,7 +320,7 @@
|
||||
"source": [
|
||||
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
|
||||
"\n",
|
||||
"Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
|
||||
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -389,9 +388,8 @@
|
||||
"def inverse(fun):\n",
|
||||
" @wraps(fun)\n",
|
||||
" def wrapped(*args, **kwargs):\n",
|
||||
" # Since we assume unary functions, we won't\n",
|
||||
" # worry about flattening and\n",
|
||||
" # unflattening arguments\n",
|
||||
" # Since we assume unary functions, we won't worry about flattening and\n",
|
||||
" # unflattening arguments.\n",
|
||||
" closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)\n",
|
||||
" out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)\n",
|
||||
" return out[0]\n",
|
||||
@ -434,9 +432,8 @@
|
||||
" # outvars are now invars \n",
|
||||
" invals = safe_map(read, eqn.outvars)\n",
|
||||
" if eqn.primitive not in inverse_registry:\n",
|
||||
" raise NotImplementedError(\"{} does not have registered inverse.\".format(\n",
|
||||
" eqn.primitive\n",
|
||||
" ))\n",
|
||||
" raise NotImplementedError(\n",
|
||||
" f\"{eqn.primitive} does not have registered inverse.\")\n",
|
||||
" # Assuming a unary function \n",
|
||||
" outval = inverse_registry[eqn.primitive](*invals)\n",
|
||||
" safe_map(write, eqn.invars, [outval])\n",
|
||||
|
@ -70,9 +70,8 @@ for function transformation.
|
||||
|
||||
To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a "pretty-printing" transformation:
|
||||
it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.
|
||||
Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.
|
||||
Let's use it to look at how some example Jaxprs
|
||||
are structured.
|
||||
`make_jaxpr` is useful for debugging and introspection.
|
||||
Let's use it to look at how some example Jaxprs are structured.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: RSxEiWi-EeYW
|
||||
@ -139,7 +138,7 @@ The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interp
|
||||
|
||||
### 1. Tracing a function
|
||||
|
||||
We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`.
|
||||
Let's use `make_jaxpr` to trace a function into a Jaxpr.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: BHkg_3P1pXJj
|
||||
@ -155,8 +154,8 @@ from jax._src.util import safe_map
|
||||
|
||||
+++ {"id": "CpTml2PTrzZ4"}
|
||||
|
||||
This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr
|
||||
from a list of partial value inputs.
|
||||
`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with
|
||||
the constants (`literals`) from the trace.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: Tc1REN5aq_fH
|
||||
@ -165,7 +164,7 @@ def f(x):
|
||||
return jnp.exp(jnp.tanh(x))
|
||||
|
||||
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
|
||||
print(closed_jaxpr)
|
||||
print(closed_jaxpr.jaxpr)
|
||||
print(closed_jaxpr.literals)
|
||||
```
|
||||
|
||||
@ -224,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
|
||||
|
||||
Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
|
||||
|
||||
Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
|
||||
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
|
||||
|
||||
+++ {"id": "0vb2ZoGrCMM4"}
|
||||
|
||||
@ -261,9 +260,8 @@ inverse_registry[lax.tanh_p] = jnp.arctanh
|
||||
def inverse(fun):
|
||||
@wraps(fun)
|
||||
def wrapped(*args, **kwargs):
|
||||
# Since we assume unary functions, we won't
|
||||
# worry about flattening and
|
||||
# unflattening arguments
|
||||
# Since we assume unary functions, we won't worry about flattening and
|
||||
# unflattening arguments.
|
||||
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
|
||||
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
|
||||
return out[0]
|
||||
@ -296,9 +294,8 @@ def inverse_jaxpr(jaxpr, consts, *args):
|
||||
# outvars are now invars
|
||||
invals = safe_map(read, eqn.outvars)
|
||||
if eqn.primitive not in inverse_registry:
|
||||
raise NotImplementedError("{} does not have registered inverse.".format(
|
||||
eqn.primitive
|
||||
))
|
||||
raise NotImplementedError(
|
||||
f"{eqn.primitive} does not have registered inverse.")
|
||||
# Assuming a unary function
|
||||
outval = inverse_registry[eqn.primitive](*invals)
|
||||
safe_map(write, eqn.invars, [outval])
|
||||
|
Loading…
x
Reference in New Issue
Block a user