mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10074 from jakevdp:jit-partial-doc
PiperOrigin-RevId: 438193490
This commit is contained in:
commit
9da5f4e793
@ -431,7 +431,7 @@
|
||||
"id": "5XUT2acoHBz-"
|
||||
},
|
||||
"source": [
|
||||
"If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. It is only a good strategy if the function is guaranteed to get limited different values."
|
||||
"If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -472,10 +472,46 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"g_jit_correct = jax.jit(g, static_argnums=1)\n",
|
||||
"g_jit_correct = jax.jit(g, static_argnames=['n'])\n",
|
||||
"print(g_jit_correct(10, 20))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "2X5rR4jkIO",
|
||||
"outputId": "81-4744-dc2e4-4e10f470f2-a19e71d9121"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"30\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"@partial(jax.jit, static_argnames=['n'])\n",
|
||||
"def g_jit_decorated(x, n):\n",
|
||||
" i = 0\n",
|
||||
" while i < n:\n",
|
||||
" i += 1\n",
|
||||
" return x + i\n",
|
||||
"\n",
|
||||
"print(g_jit_decorated(10, 20))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
@ -223,7 +223,7 @@ g_inner_jitted(10, 20)
|
||||
|
||||
+++ {"id": "5XUT2acoHBz-"}
|
||||
|
||||
If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified input. It is only a good strategy if the function is guaranteed to get limited different values.
|
||||
If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
|
||||
|
||||
```{code-cell}
|
||||
:id: 2yQmQTDNAenY
|
||||
@ -237,10 +237,28 @@ print(f_jit_correct(10))
|
||||
:id: R4SXUEu-M-u1
|
||||
:outputId: 9e712e14-4e81-4744-dcf2-a10f470d9121
|
||||
|
||||
g_jit_correct = jax.jit(g, static_argnums=1)
|
||||
g_jit_correct = jax.jit(g, static_argnames=['n'])
|
||||
print(g_jit_correct(10, 20))
|
||||
```
|
||||
|
||||
To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:
|
||||
|
||||
```{code-cell}
|
||||
:id: 2X5rR4jkIO
|
||||
:outputId: 81-4744-dc2e4-4e10f470f2-a19e71d9121
|
||||
|
||||
from functools import partial
|
||||
|
||||
@partial(jax.jit, static_argnames=['n'])
|
||||
def g_jit_decorated(x, n):
|
||||
i = 0
|
||||
while i < n:
|
||||
i += 1
|
||||
return x + i
|
||||
|
||||
print(g_jit_decorated(10, 20))
|
||||
```
|
||||
|
||||
+++ {"id": "LczjIBt2X2Ms"}
|
||||
|
||||
## When to use JIT
|
||||
|
@ -297,20 +297,35 @@ def jit(
|
||||
Returns:
|
||||
A wrapped version of ``fun``, set up for just-in-time compilation.
|
||||
|
||||
In the following example, ``selu`` can be compiled into a single fused kernel
|
||||
by XLA:
|
||||
Examples:
|
||||
In the following example, ``selu`` can be compiled into a single fused kernel
|
||||
by XLA:
|
||||
|
||||
>>> import jax
|
||||
>>>
|
||||
>>> @jax.jit
|
||||
... def selu(x, alpha=1.67, lmbda=1.05):
|
||||
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||
>>>
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
>>> x = jax.random.normal(key, (10,))
|
||||
>>> print(selu(x)) # doctest: +SKIP
|
||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||
-0.85743 -0.78232 0.76827 0.59566 ]
|
||||
>>> import jax
|
||||
>>>
|
||||
>>> @jax.jit
|
||||
... def selu(x, alpha=1.67, lmbda=1.05):
|
||||
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||
>>>
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
>>> x = jax.random.normal(key, (10,))
|
||||
>>> print(selu(x)) # doctest: +SKIP
|
||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||
-0.85743 -0.78232 0.76827 0.59566 ]
|
||||
|
||||
To pass arguments such as ``static_argnames`` when decorating a function, a common
|
||||
pattern is to use :func:`functools.partial`:
|
||||
|
||||
>>> from functools import partial
|
||||
>>>
|
||||
>>> @partial(jax.jit, static_argnames=['n'])
|
||||
... def g(x, n):
|
||||
... for i in range(n):
|
||||
... x = x ** 2
|
||||
... return x
|
||||
>>>
|
||||
>>> g(jnp.arange(4), 3)
|
||||
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
|
||||
"""
|
||||
if FLAGS.experimental_cpp_jit:
|
||||
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
|
||||
|
Loading…
x
Reference in New Issue
Block a user