Document jax.disable_jit. Add an example to jax.grad.

This commit is contained in:
Peter Hawkins 2019-02-20 09:00:12 -05:00
parent 7f65ec9394
commit 67174e3d57
2 changed files with 41 additions and 2 deletions

View File

@ -17,6 +17,6 @@ Module contents
---------------
.. automodule:: jax
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
:members: jit, disable_jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
:undoc-members:
:show-inheritance:

View File

@ -81,7 +81,7 @@ def jit(fun, static_argnums=()):
In the following example, `selu` can be compiled into a single fused kernel by
XLA:
>>> @jit
>>> @jax.jit
>>> def selu(x, alpha=1.67, lmbda=1.05):
>>> return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
@ -112,6 +112,38 @@ def jit(fun, static_argnums=()):
@contextmanager
def disable_jit():
"""Context manager that disables `jit`.
For debugging purposes, it is useful to have a mechanism that disables `jit`
everywhere in a block of code, namely the `disable_jit` decorator.
Inside a `jit`-ted function the values flowing through
traced code can be abstract (i.e., shaped arrays with an unknown values),
instead of concrete (i.e., specific arrays with known values).
For example:
>>> @jax.jit
>>> def f(x):
>>> y = x *2
>>> print("Value of y is", y)
>>> return y + 3
>>>
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
[5 7 9]
Here `y` has been abstracted by `jit` to a `ShapedArray`, which represents an
array with a fixed shape and type but an arbitrary value. If we want to see a
concrete values while debugging, we can use the `disable_jit` decorator, at
the cost of slower code:
>>> with jax.disable_jit():
>>> print(f(np.array([1, 2, 3])))
>>>
Value of y is [2 4 6]
[5 7 9]
"""
global _jit_is_disabled
_jit_is_disabled, prev_val = True, _jit_is_disabled
yield
@ -136,6 +168,13 @@ def grad(fun, argnums=0):
type as the positional argument indicated by that integer. If argnums is a
tuple of integers, the gradient is a tuple of values with the same shapes
and types as the corresponding arguments.
For example:
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> grad_tanh(0.2)
array(0.961043, dtype=float32)
"""
value_and_grad_f = value_and_grad(fun, argnums)