Merge pull request #10074 from jakevdp:jit-partial-doc

PiperOrigin-RevId: 438193490
This commit is contained in:
jax authors 2022-03-29 20:02:17 -07:00
commit 9da5f4e793
3 changed files with 86 additions and 17 deletions

View File

@ -431,7 +431,7 @@
"id": "5XUT2acoHBz-"
},
"source": [
"If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. It is only a good strategy if the function is guaranteed to get limited different values."
"If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values."
]
},
{
@ -472,10 +472,46 @@
}
],
"source": [
"g_jit_correct = jax.jit(g, static_argnums=1)\n",
"g_jit_correct = jax.jit(g, static_argnames=['n'])\n",
"print(g_jit_correct(10, 20))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2X5rR4jkIO",
"outputId": "81-4744-dc2e4-4e10f470f2-a19e71d9121"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"@partial(jax.jit, static_argnames=['n'])\n",
"def g_jit_decorated(x, n):\n",
" i = 0\n",
" while i < n:\n",
" i += 1\n",
" return x + i\n",
"\n",
"print(g_jit_decorated(10, 20))"
]
},
{
"cell_type": "markdown",
"metadata": {

View File

@ -223,7 +223,7 @@ g_inner_jitted(10, 20)
+++ {"id": "5XUT2acoHBz-"}
If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. It is only a good strategy if the function is guaranteed to get limited different values.
If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
```{code-cell}
:id: 2yQmQTDNAenY
@ -237,10 +237,28 @@ print(f_jit_correct(10))
:id: R4SXUEu-M-u1
:outputId: 9e712e14-4e81-4744-dcf2-a10f470d9121
g_jit_correct = jax.jit(g, static_argnums=1)
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
```
To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:
```{code-cell}
:id: 2X5rR4jkIO
:outputId: 81-4744-dc2e4-4e10f470f2-a19e71d9121
from functools import partial
@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
i = 0
while i < n:
i += 1
return x + i
print(g_jit_decorated(10, 20))
```
+++ {"id": "LczjIBt2X2Ms"}
## When to use JIT

View File

@ -297,20 +297,35 @@ def jit(
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation.
In the following example, ``selu`` can be compiled into a single fused kernel
by XLA:
Examples:
In the following example, ``selu`` can be compiled into a single fused kernel
by XLA:
>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x)) # doctest: +SKIP
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ]
>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x)) # doctest: +SKIP
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ]
To pass arguments such as ``static_argnames`` when decorating a function, a common
pattern is to use :func:`functools.partial`:
>>> from functools import partial
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
... for i in range(n):
... x = x ** 2
... return x
>>>
>>> g(jnp.arange(4), 3)
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
"""
if FLAGS.experimental_cpp_jit:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,