mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
DOC: re-enable execution of thinking_in_jax.ipynb
This commit is contained in:
parent
36cdafdcf4
commit
130a53f2a2
@ -85,6 +85,8 @@ suppress_warnings = [
|
||||
'ref.citation', # Many duplicated citations in numpy/scipy docstrings.
|
||||
'ref.footnote', # Many unreferenced footnotes in numpy/scipy docstrings
|
||||
'myst.header',
|
||||
# TODO(jakevdp): remove this suppression once issue is fixed.
|
||||
'misc.highlighting_failure', # https://github.com/ipython/ipython/issues/14142
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
@ -199,8 +201,6 @@ nb_execution_excludepatterns = [
|
||||
'notebooks/neural_network_with_tfds_data.*',
|
||||
# Slow notebook
|
||||
'notebooks/Neural_Network_and_Data_Loading.*',
|
||||
# Strange error apparently due to asynchronous cell execution
|
||||
'notebooks/thinking_in_jax.*',
|
||||
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
|
||||
'jep/9407-type-promotion.*',
|
||||
# TODO(jakevdp): enable execution on the following if possible:
|
||||
|
File diff suppressed because one or more lines are too long
@ -11,24 +11,6 @@ kernelspec:
|
||||
name: python3
|
||||
---
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: aPUwOm-eCSFD
|
||||
:tags: [remove-cell]
|
||||
|
||||
# Configure ipython to hide long tracebacks.
|
||||
import sys
|
||||
ipython = get_ipython()
|
||||
|
||||
def minimal_traceback(*args, **kwargs):
|
||||
etype, value, tb = sys.exc_info()
|
||||
value.__cause__ = None # suppress chained exceptions
|
||||
stb = ipython.InteractiveTB.structured_traceback(etype, value, tb)
|
||||
del stb[3:-1]
|
||||
return ipython._showtraceback(etype, value, stb)
|
||||
|
||||
ipython.showtraceback = minimal_traceback
|
||||
```
|
||||
|
||||
+++ {"id": "LQHmwePqryRU"}
|
||||
|
||||
# How to Think in JAX
|
||||
@ -51,7 +33,7 @@ NumPy provides a well-known, powerful API for working with numerical data. For c
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: kZaOXL7-uvUP
|
||||
:outputId: 17a9ee0a-8719-44bb-a9fe-4c9f24649fef
|
||||
:outputId: 7fd4dd8e-4194-4983-ac6b-28059f8feb90
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@ -63,7 +45,7 @@ plt.plot(x_np, y_np);
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 18XbGpRLuZlr
|
||||
:outputId: 9e98d928-1925-45b1-d886-37956ca95e7c
|
||||
:outputId: 3d073b3c-913f-410b-ee33-b3a0eb878436
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -80,14 +62,14 @@ The arrays themselves are implemented as different Python types:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: PjFFunI7xNe8
|
||||
:outputId: e1706c61-2821-437a-efcd-d8082f913c1f
|
||||
:outputId: d3b0007e-7997-45c0-d4b8-9f5699cedcbc
|
||||
|
||||
type(x_np)
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: kpv5K7QYxQnX
|
||||
:outputId: 8a3f1cb6-c6d6-494c-8efe-24a8217a9d55
|
||||
:outputId: ba68a1de-f938-477d-9942-83a839aeca09
|
||||
|
||||
type(x_jnp)
|
||||
```
|
||||
@ -102,7 +84,7 @@ Here is an example of mutating an array in NumPy:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: fzp-y1ZVyGD4
|
||||
:outputId: 300a44cc-1ccd-4fb2-f0ee-2179763f7690
|
||||
:outputId: 6eb76bf8-0edd-43a5-b2be-85a79fb23190
|
||||
|
||||
# NumPy: mutable arrays
|
||||
x = np.arange(10)
|
||||
@ -114,9 +96,16 @@ print(x)
|
||||
|
||||
The equivalent in JAX results in an error, as JAX arrays are immutable:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: l2AP0QERb0P7
|
||||
:outputId: 528a8e5f-538f-4739-fe95-1c3605ba8c8a
|
||||
|
||||
%xmode minimal
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: pCPX0JR-yM4i
|
||||
:outputId: 02a442bc-8f23-4dce-9500-81cd28c0b21f
|
||||
:outputId: c7bf4afd-8b7f-4dac-d065-8189679861d6
|
||||
:tags: [raises-exception]
|
||||
|
||||
# JAX: immutable arrays
|
||||
@ -130,7 +119,7 @@ For updating individual elements, JAX provides an [indexed update syntax](https:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 8zqPEAeP3UK5
|
||||
:outputId: 7e6c996d-d0b0-4d52-e722-410ba78eb3b1
|
||||
:outputId: 20a40c26-3419-4e60-bd2c-83ad30bd7650
|
||||
|
||||
y = x.at[0].set(10)
|
||||
print(x)
|
||||
@ -155,7 +144,7 @@ For example, while `jax.numpy` will implicitly promote arguments to allow operat
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: c6EFPcj12mw0
|
||||
:outputId: 730e2ca4-30a5-45bc-923c-c3a5143496e2
|
||||
:outputId: 827d09eb-c8aa-43bc-b471-0a6c9c4f6601
|
||||
|
||||
import jax.numpy as jnp
|
||||
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
|
||||
@ -163,7 +152,7 @@ jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 0VkqlcXL2qSp
|
||||
:outputId: 601b0562-3e6a-402d-f83b-3afdd1e7e7c4
|
||||
:outputId: 7e1e9233-2fe1-46a8-8eb1-1d1dbc54b58c
|
||||
:tags: [raises-exception]
|
||||
|
||||
from jax import lax
|
||||
@ -176,7 +165,7 @@ If using `jax.lax` directly, you'll have to do type promotion explicitly in such
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 3PNQlieT81mi
|
||||
:outputId: cb3ed074-f410-456f-c086-23107eae2634
|
||||
:outputId: 4bd2b6f3-d2d1-44cb-f8ee-18976ae40239
|
||||
|
||||
lax.add(jnp.float32(1), 1.0)
|
||||
```
|
||||
@ -189,7 +178,7 @@ For example, consider a 1D convolution, which can be expressed in NumPy this way
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: Bv-7XexyzVCN
|
||||
:outputId: f5d38cd8-e7fc-49e2-bff3-a0eee306cb54
|
||||
:outputId: d570f64a-ca61-456f-8cab-6cd643cb8ea1
|
||||
|
||||
x = jnp.array([1, 2, 1])
|
||||
y = jnp.ones(10)
|
||||
@ -202,7 +191,7 @@ Under the hood, this NumPy operation is translated to a much more general convol
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: pi4f6ikjzc3l
|
||||
:outputId: b9b37edc-b911-4010-aaf8-ee8f500111d7
|
||||
:outputId: 0bb56ae2-7837-4c04-ff8b-6cbc0565b7d7
|
||||
|
||||
from jax import lax
|
||||
result = lax.conv_general_dilated(
|
||||
@ -261,7 +250,7 @@ This function returns the same results as the original, up to standard floating-
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: oz7zzyS3AwMc
|
||||
:outputId: 914f9242-82c4-4365-abb2-77843a704e03
|
||||
:outputId: ed1c796c-59f8-4238-f6e2-f54330edadf0
|
||||
|
||||
np.random.seed(1701)
|
||||
X = jnp.array(np.random.rand(10000, 10))
|
||||
@ -274,7 +263,7 @@ But due to the compilation (which includes fusing of operations, avoidance of al
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 6mUB6VdDAEIY
|
||||
:outputId: 5d7e1bbd-4064-4fe3-f3d9-5435b5283199
|
||||
:outputId: 1050a69c-e713-44c1-b3eb-1ef875691978
|
||||
|
||||
%timeit norm(X).block_until_ready()
|
||||
%timeit norm_compiled(X).block_until_ready()
|
||||
@ -288,7 +277,7 @@ For example, this operation can be executed in op-by-op mode:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: YfZd9mW7CSKM
|
||||
:outputId: 899fedcc-0857-4381-8f57-bb653e0aa2f1
|
||||
:outputId: 6fdbfde4-7cde-447f-badf-26e1f8db288d
|
||||
|
||||
def get_negatives(x):
|
||||
return x[x < 0]
|
||||
@ -303,7 +292,7 @@ But it returns an error if you attempt to execute it in jit mode:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: yYWvE4rxCjPK
|
||||
:outputId: 765b46d3-49cd-41b7-9815-e8bb7cd80175
|
||||
:outputId: 9cf7f2d4-8f28-4265-d701-d52086cfd437
|
||||
:tags: [raises-exception]
|
||||
|
||||
jit(get_negatives)(x)
|
||||
@ -327,7 +316,7 @@ To use `jax.jit` effectively, it is useful to understand how it works. Let's put
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: TfjVIVuD4gnc
|
||||
:outputId: df6ad898-b047-4ad1-eb18-2fbcb3fd2ab3
|
||||
:outputId: 9f4ddcaa-8ab7-4984-afb6-47fede5314ea
|
||||
|
||||
@jit
|
||||
def f(x, y):
|
||||
@ -353,7 +342,7 @@ When we call the compiled function again on matching inputs, no re-compilation i
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: xGntvzNH7skE
|
||||
:outputId: 66694b8b-181f-4635-a8e2-1fc7f244d94b
|
||||
:outputId: 43aaeee6-3853-4b00-fb2b-646df695204a
|
||||
|
||||
x2 = np.random.randn(3, 4)
|
||||
y2 = np.random.randn(4)
|
||||
@ -366,7 +355,7 @@ The extracted sequence of operations is encoded in a JAX expression, or *jaxpr*
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 89TMp_Op5-JZ
|
||||
:outputId: 151210e2-af6f-4950-ac1e-9fdb81d4aae1
|
||||
:outputId: 48212815-059a-4af1-de82-cd39ecac264a
|
||||
|
||||
from jax import make_jaxpr
|
||||
|
||||
@ -382,7 +371,7 @@ Note one consequence of this: because JIT compilation is done *without* informat
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: A0rFdM95-Ix_
|
||||
:outputId: d7ffa367-b241-488e-df96-ad0576536605
|
||||
:outputId: e37bf04e-6a6a-4536-e423-f082f52d5f11
|
||||
:tags: [raises-exception]
|
||||
|
||||
@jit
|
||||
@ -398,7 +387,7 @@ If there are variables that you would not like to be traced, they can be marked
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: K1C7ZnVv-lbv
|
||||
:outputId: cdbdf152-30fd-4ecb-c9ec-1d1124f337f7
|
||||
:outputId: e9d6cce3-b036-43da-ad99-887af9625ab0
|
||||
|
||||
from functools import partial
|
||||
|
||||
@ -415,7 +404,7 @@ Note that calling a JIT-compiled function with a different static argument resul
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: sXqczBOrG7-w
|
||||
:outputId: 3a3f50e6-d1fc-42bb-d6df-eb3d206e4b67
|
||||
:outputId: 5fb7c278-b87e-4a6b-ef50-5e4e9c765b52
|
||||
|
||||
f(1, False)
|
||||
```
|
||||
@ -440,7 +429,7 @@ This distinction between static and traced values makes it important to think ab
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: XJCQ7slcD4iU
|
||||
:outputId: a89a5614-7359-4dc7-c165-03e7d0fc6610
|
||||
:outputId: 3646dea0-f6b6-48e9-9dc0-c4dec7816b7a
|
||||
:tags: [raises-exception]
|
||||
|
||||
import jax.numpy as jnp
|
||||
@ -460,7 +449,7 @@ This fails with an error specifying that a tracer was found instead of a 1D sequ
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: Cb4mbeVZEi_q
|
||||
:outputId: f72c1ce3-950c-400f-bfea-10c0d0118911
|
||||
:outputId: 30d8621f-34e1-4e1d-e6c4-c3e0d8769ec4
|
||||
|
||||
@jit
|
||||
def f(x):
|
||||
@ -481,7 +470,7 @@ A useful pattern is to use `numpy` for operations that should be static (i.e. do
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: GiovOOPcGJhg
|
||||
:outputId: 399ee059-1807-4866-9beb-1c5131e38e15
|
||||
:outputId: 5363ad1b-23d9-4dd6-d9db-95a6c9de05da
|
||||
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
|
Loading…
x
Reference in New Issue
Block a user