DOC: re-enable execution of thinking_in_jax.ipynb

This commit is contained in:
Jake VanderPlas 2023-08-24 09:23:26 -07:00
parent 36cdafdcf4
commit 130a53f2a2
3 changed files with 152 additions and 216 deletions

View File

@ -85,6 +85,8 @@ suppress_warnings = [
'ref.citation', # Many duplicated citations in numpy/scipy docstrings.
'ref.footnote', # Many unreferenced footnotes in numpy/scipy docstrings
'myst.header',
# TODO(jakevdp): remove this suppression once issue is fixed.
'misc.highlighting_failure', # https://github.com/ipython/ipython/issues/14142
]
# Add any paths that contain templates here, relative to this directory.
@ -199,8 +201,6 @@ nb_execution_excludepatterns = [
'notebooks/neural_network_with_tfds_data.*',
# Slow notebook
'notebooks/Neural_Network_and_Data_Loading.*',
# Strange error apparently due to asynchronous cell execution
'notebooks/thinking_in_jax.*',
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
'jep/9407-type-promotion.*',
# TODO(jakevdp): enable execution on the following if possible:

File diff suppressed because one or more lines are too long

View File

@ -11,24 +11,6 @@ kernelspec:
name: python3
---
```{code-cell} ipython3
:id: aPUwOm-eCSFD
:tags: [remove-cell]
# Configure ipython to hide long tracebacks.
import sys
ipython = get_ipython()
def minimal_traceback(*args, **kwargs):
etype, value, tb = sys.exc_info()
value.__cause__ = None # suppress chained exceptions
stb = ipython.InteractiveTB.structured_traceback(etype, value, tb)
del stb[3:-1]
return ipython._showtraceback(etype, value, stb)
ipython.showtraceback = minimal_traceback
```
+++ {"id": "LQHmwePqryRU"}
# How to Think in JAX
@ -51,7 +33,7 @@ NumPy provides a well-known, powerful API for working with numerical data. For c
```{code-cell} ipython3
:id: kZaOXL7-uvUP
:outputId: 17a9ee0a-8719-44bb-a9fe-4c9f24649fef
:outputId: 7fd4dd8e-4194-4983-ac6b-28059f8feb90
import matplotlib.pyplot as plt
import numpy as np
@ -63,7 +45,7 @@ plt.plot(x_np, y_np);
```{code-cell} ipython3
:id: 18XbGpRLuZlr
:outputId: 9e98d928-1925-45b1-d886-37956ca95e7c
:outputId: 3d073b3c-913f-410b-ee33-b3a0eb878436
import jax.numpy as jnp
@ -80,14 +62,14 @@ The arrays themselves are implemented as different Python types:
```{code-cell} ipython3
:id: PjFFunI7xNe8
:outputId: e1706c61-2821-437a-efcd-d8082f913c1f
:outputId: d3b0007e-7997-45c0-d4b8-9f5699cedcbc
type(x_np)
```
```{code-cell} ipython3
:id: kpv5K7QYxQnX
:outputId: 8a3f1cb6-c6d6-494c-8efe-24a8217a9d55
:outputId: ba68a1de-f938-477d-9942-83a839aeca09
type(x_jnp)
```
@ -102,7 +84,7 @@ Here is an example of mutating an array in NumPy:
```{code-cell} ipython3
:id: fzp-y1ZVyGD4
:outputId: 300a44cc-1ccd-4fb2-f0ee-2179763f7690
:outputId: 6eb76bf8-0edd-43a5-b2be-85a79fb23190
# NumPy: mutable arrays
x = np.arange(10)
@ -114,9 +96,16 @@ print(x)
The equivalent in JAX results in an error, as JAX arrays are immutable:
```{code-cell} ipython3
:id: l2AP0QERb0P7
:outputId: 528a8e5f-538f-4739-fe95-1c3605ba8c8a
%xmode minimal
```
```{code-cell} ipython3
:id: pCPX0JR-yM4i
:outputId: 02a442bc-8f23-4dce-9500-81cd28c0b21f
:outputId: c7bf4afd-8b7f-4dac-d065-8189679861d6
:tags: [raises-exception]
# JAX: immutable arrays
@ -130,7 +119,7 @@ For updating individual elements, JAX provides an [indexed update syntax](https:
```{code-cell} ipython3
:id: 8zqPEAeP3UK5
:outputId: 7e6c996d-d0b0-4d52-e722-410ba78eb3b1
:outputId: 20a40c26-3419-4e60-bd2c-83ad30bd7650
y = x.at[0].set(10)
print(x)
@ -155,7 +144,7 @@ For example, while `jax.numpy` will implicitly promote arguments to allow operat
```{code-cell} ipython3
:id: c6EFPcj12mw0
:outputId: 730e2ca4-30a5-45bc-923c-c3a5143496e2
:outputId: 827d09eb-c8aa-43bc-b471-0a6c9c4f6601
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
@ -163,7 +152,7 @@ jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
```{code-cell} ipython3
:id: 0VkqlcXL2qSp
:outputId: 601b0562-3e6a-402d-f83b-3afdd1e7e7c4
:outputId: 7e1e9233-2fe1-46a8-8eb1-1d1dbc54b58c
:tags: [raises-exception]
from jax import lax
@ -176,7 +165,7 @@ If using `jax.lax` directly, you'll have to do type promotion explicitly in such
```{code-cell} ipython3
:id: 3PNQlieT81mi
:outputId: cb3ed074-f410-456f-c086-23107eae2634
:outputId: 4bd2b6f3-d2d1-44cb-f8ee-18976ae40239
lax.add(jnp.float32(1), 1.0)
```
@ -189,7 +178,7 @@ For example, consider a 1D convolution, which can be expressed in NumPy this way
```{code-cell} ipython3
:id: Bv-7XexyzVCN
:outputId: f5d38cd8-e7fc-49e2-bff3-a0eee306cb54
:outputId: d570f64a-ca61-456f-8cab-6cd643cb8ea1
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
@ -202,7 +191,7 @@ Under the hood, this NumPy operation is translated to a much more general convol
```{code-cell} ipython3
:id: pi4f6ikjzc3l
:outputId: b9b37edc-b911-4010-aaf8-ee8f500111d7
:outputId: 0bb56ae2-7837-4c04-ff8b-6cbc0565b7d7
from jax import lax
result = lax.conv_general_dilated(
@ -261,7 +250,7 @@ This function returns the same results as the original, up to standard floating-
```{code-cell} ipython3
:id: oz7zzyS3AwMc
:outputId: 914f9242-82c4-4365-abb2-77843a704e03
:outputId: ed1c796c-59f8-4238-f6e2-f54330edadf0
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
@ -274,7 +263,7 @@ But due to the compilation (which includes fusing of operations, avoidance of al
```{code-cell} ipython3
:id: 6mUB6VdDAEIY
:outputId: 5d7e1bbd-4064-4fe3-f3d9-5435b5283199
:outputId: 1050a69c-e713-44c1-b3eb-1ef875691978
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
@ -288,7 +277,7 @@ For example, this operation can be executed in op-by-op mode:
```{code-cell} ipython3
:id: YfZd9mW7CSKM
:outputId: 899fedcc-0857-4381-8f57-bb653e0aa2f1
:outputId: 6fdbfde4-7cde-447f-badf-26e1f8db288d
def get_negatives(x):
return x[x < 0]
@ -303,7 +292,7 @@ But it returns an error if you attempt to execute it in jit mode:
```{code-cell} ipython3
:id: yYWvE4rxCjPK
:outputId: 765b46d3-49cd-41b7-9815-e8bb7cd80175
:outputId: 9cf7f2d4-8f28-4265-d701-d52086cfd437
:tags: [raises-exception]
jit(get_negatives)(x)
@ -327,7 +316,7 @@ To use `jax.jit` effectively, it is useful to understand how it works. Let's put
```{code-cell} ipython3
:id: TfjVIVuD4gnc
:outputId: df6ad898-b047-4ad1-eb18-2fbcb3fd2ab3
:outputId: 9f4ddcaa-8ab7-4984-afb6-47fede5314ea
@jit
def f(x, y):
@ -353,7 +342,7 @@ When we call the compiled function again on matching inputs, no re-compilation i
```{code-cell} ipython3
:id: xGntvzNH7skE
:outputId: 66694b8b-181f-4635-a8e2-1fc7f244d94b
:outputId: 43aaeee6-3853-4b00-fb2b-646df695204a
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
@ -366,7 +355,7 @@ The extracted sequence of operations is encoded in a JAX expression, or *jaxpr*
```{code-cell} ipython3
:id: 89TMp_Op5-JZ
:outputId: 151210e2-af6f-4950-ac1e-9fdb81d4aae1
:outputId: 48212815-059a-4af1-de82-cd39ecac264a
from jax import make_jaxpr
@ -382,7 +371,7 @@ Note one consequence of this: because JIT compilation is done *without* informat
```{code-cell} ipython3
:id: A0rFdM95-Ix_
:outputId: d7ffa367-b241-488e-df96-ad0576536605
:outputId: e37bf04e-6a6a-4536-e423-f082f52d5f11
:tags: [raises-exception]
@jit
@ -398,7 +387,7 @@ If there are variables that you would not like to be traced, they can be marked
```{code-cell} ipython3
:id: K1C7ZnVv-lbv
:outputId: cdbdf152-30fd-4ecb-c9ec-1d1124f337f7
:outputId: e9d6cce3-b036-43da-ad99-887af9625ab0
from functools import partial
@ -415,7 +404,7 @@ Note that calling a JIT-compiled function with a different static argument resul
```{code-cell} ipython3
:id: sXqczBOrG7-w
:outputId: 3a3f50e6-d1fc-42bb-d6df-eb3d206e4b67
:outputId: 5fb7c278-b87e-4a6b-ef50-5e4e9c765b52
f(1, False)
```
@ -440,7 +429,7 @@ This distinction between static and traced values makes it important to think ab
```{code-cell} ipython3
:id: XJCQ7slcD4iU
:outputId: a89a5614-7359-4dc7-c165-03e7d0fc6610
:outputId: 3646dea0-f6b6-48e9-9dc0-c4dec7816b7a
:tags: [raises-exception]
import jax.numpy as jnp
@ -460,7 +449,7 @@ This fails with an error specifying that a tracer was found instead of a 1D sequ
```{code-cell} ipython3
:id: Cb4mbeVZEi_q
:outputId: f72c1ce3-950c-400f-bfea-10c0d0118911
:outputId: 30d8621f-34e1-4e1d-e6c4-c3e0d8769ec4
@jit
def f(x):
@ -481,7 +470,7 @@ A useful pattern is to use `numpy` for operations that should be static (i.e. do
```{code-cell} ipython3
:id: GiovOOPcGJhg
:outputId: 399ee059-1807-4866-9beb-1c5131e38e15
:outputId: 5363ad1b-23d9-4dd6-d9db-95a6c9de05da
from jax import jit
import jax.numpy as jnp