mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Document jax.disable_jit. Add an example to jax.grad.
This commit is contained in:
parent
7f65ec9394
commit
67174e3d57
@ -17,6 +17,6 @@ Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
|
||||
:members: jit, disable_jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
41
jax/api.py
41
jax/api.py
@ -81,7 +81,7 @@ def jit(fun, static_argnums=()):
|
||||
In the following example, `selu` can be compiled into a single fused kernel by
|
||||
XLA:
|
||||
|
||||
>>> @jit
|
||||
>>> @jax.jit
|
||||
>>> def selu(x, alpha=1.67, lmbda=1.05):
|
||||
>>> return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||
>>>
|
||||
@ -112,6 +112,38 @@ def jit(fun, static_argnums=()):
|
||||
|
||||
@contextmanager
|
||||
def disable_jit():
|
||||
"""Context manager that disables `jit`.
|
||||
|
||||
For debugging purposes, it is useful to have a mechanism that disables `jit`
|
||||
everywhere in a block of code, namely the `disable_jit` decorator.
|
||||
|
||||
Inside a `jit`-ted function the values flowing through
|
||||
traced code can be abstract (i.e., shaped arrays with an unknown values),
|
||||
instead of concrete (i.e., specific arrays with known values).
|
||||
|
||||
For example:
|
||||
|
||||
>>> @jax.jit
|
||||
>>> def f(x):
|
||||
>>> y = x *2
|
||||
>>> print("Value of y is", y)
|
||||
>>> return y + 3
|
||||
>>>
|
||||
>>> print(f(jax.numpy.array([1, 2, 3])))
|
||||
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
|
||||
[5 7 9]
|
||||
|
||||
Here `y` has been abstracted by `jit` to a `ShapedArray`, which represents an
|
||||
array with a fixed shape and type but an arbitrary value. If we want to see a
|
||||
concrete values while debugging, we can use the `disable_jit` decorator, at
|
||||
the cost of slower code:
|
||||
|
||||
>>> with jax.disable_jit():
|
||||
>>> print(f(np.array([1, 2, 3])))
|
||||
>>>
|
||||
Value of y is [2 4 6]
|
||||
[5 7 9]
|
||||
"""
|
||||
global _jit_is_disabled
|
||||
_jit_is_disabled, prev_val = True, _jit_is_disabled
|
||||
yield
|
||||
@ -136,6 +168,13 @@ def grad(fun, argnums=0):
|
||||
type as the positional argument indicated by that integer. If argnums is a
|
||||
tuple of integers, the gradient is a tuple of values with the same shapes
|
||||
and types as the corresponding arguments.
|
||||
|
||||
For example:
|
||||
|
||||
>>> grad_tanh = jax.grad(jax.numpy.tanh)
|
||||
>>> grad_tanh(0.2)
|
||||
array(0.961043, dtype=float32)
|
||||
|
||||
"""
|
||||
value_and_grad_f = value_and_grad(fun, argnums)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user