Clarify in JAX Basics that JAX array creation is also an operation that requires accelerator dispatch and converting to a regular Python type is a blocking operation.

This commit is contained in:
Peter Hawkins 2022-10-26 13:38:17 -04:00
parent 280153334b
commit 71a384d25e
2 changed files with 18 additions and 2 deletions

View File

@ -105,7 +105,15 @@
"\n",
"We will now perform a dot product to demonstrate that it can be done in different devices without changing the code. We use `%timeit` to check the performance. \n",
"\n",
"(Technical detail: when a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed asynchronously when possible. The returned array is therefore not necessarily 'filled in' as soon as the function returns. Thus, if we don't require the result immediately, the computation won't block Python execution. Therefore, unless we `block_until_ready`, we will only time the dispatch, not the actual computation. See [Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch) in the JAX docs.)"
"(Technical detail: when a JAX function is called (including `jnp.array`\n",
"creation), the corresponding operation is dispatched to an accelerator to be\n",
"computed asynchronously when possible. The returned array is therefore not\n",
"necessarily 'filled in' as soon as the function returns. Thus, if we don't\n",
"require the result immediately, the computation won't block Python execution.\n",
"Therefore, unless we `block_until_ready` or convert the array to a regular\n",
"Python type, we will only time the dispatch, not the actual computation. See\n",
"[Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch)\n",
"in the JAX docs.)"
]
},
{

View File

@ -59,7 +59,15 @@ One useful feature of JAX is that the same code can be run on different backends
We will now perform a dot product to demonstrate that it can be done in different devices without changing the code. We use `%timeit` to check the performance.
(Technical detail: when a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed asynchronously when possible. The returned array is therefore not necessarily 'filled in' as soon as the function returns. Thus, if we don't require the result immediately, the computation won't block Python execution. Therefore, unless we `block_until_ready`, we will only time the dispatch, not the actual computation. See [Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch) in the JAX docs.)
(Technical detail: when a JAX function is called (including `jnp.array`
creation), the corresponding operation is dispatched to an accelerator to be
computed asynchronously when possible. The returned array is therefore not
necessarily 'filled in' as soon as the function returns. Thus, if we don't
require the result immediately, the computation won't block Python execution.
Therefore, unless we `block_until_ready` or convert the array to a regular
Python type, we will only time the dispatch, not the actual computation. See
[Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch)
in the JAX docs.)
```{code-cell} ipython3
:id: mRvjVxoqo-Bi