mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 19:06:07 +00:00
560 lines
10 KiB
Plaintext
560 lines
10 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "ebUMqK9mGIDm"
|
|
},
|
|
"source": [
|
|
"## The basics: interactive NumPy on GPU and TPU\n",
|
|
"\n",
|
|
"---\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "27TqNtiQF97X"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"from jax import random"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "cRWoxSCNGU4o"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"key = random.key(0)\n",
|
|
"key, subkey = random.split(key)\n",
|
|
"x = random.normal(key, (5000, 5000))\n",
|
|
"\n",
|
|
"print(x.shape)\n",
|
|
"print(x.dtype)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "diPllsvgGfSA"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"y = jnp.dot(x, x)\n",
|
|
"print(y[0, 0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "8-psauxnGiRk"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "-2FMQ8UeoTJ8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"plt.plot(x[0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(jnp.dot(x, x.T))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "z4VX5PkMHJIu"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "ORZ9Odu85BCJ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"x_cpu = np.array(x)\n",
|
|
"%timeit -n 5 -r 2 np.dot(x_cpu, x_cpu)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "5BKh0eeAGvO5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "fm4Q2zpFHUAu"
|
|
},
|
|
"source": [
|
|
"## Automatic differentiation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "MCIQbyUYHWn1"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax import grad"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "kfqZpKYsHo4j"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" if x > 0:\n",
|
|
" return 2 * x ** 3\n",
|
|
" else:\n",
|
|
" return 3 * x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "K_26_odPHqLJ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"key = random.key(0)\n",
|
|
"x = random.normal(key, ())\n",
|
|
"\n",
|
|
"print(grad(f)(x))\n",
|
|
"print(grad(f)(-x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "q5V3A6loHrhS"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(grad(grad(f))(-x))\n",
|
|
"print(grad(grad(grad(f)))(-x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "bmxAPFC0I8b0"
|
|
},
|
|
"source": [
|
|
"Other JAX autodiff highlights:\n",
|
|
"\n",
|
|
"* Forward- and reverse-mode, totally composable\n",
|
|
"* Fast Jacobians and Hessians\n",
|
|
"* Complex number support (holomorphic and non-holomorphic)\n",
|
|
"* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
|
|
"\n",
|
|
"\n",
|
|
"For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "TRkxaVLJKNre"
|
|
},
|
|
"source": [
|
|
"## End-to-end compilation with XLA using `jit`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "bKo4rX9-KSW7"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax import jit"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "94iIgZSfKWh8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"key = random.key(0)\n",
|
|
"x = random.normal(key, (5000, 5000))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Ybuz8Ag9KXMd"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def f(x):\n",
|
|
" y = x\n",
|
|
" for _ in range(10):\n",
|
|
" y = y - 0.1 * y + 3.\n",
|
|
" return y[:100, :100]\n",
|
|
"\n",
|
|
"f(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Y9dx5ifSKaGJ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"g = jit(f)\n",
|
|
"g(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "UtsS67BvKYkC"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%timeit f(x).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "-vfcaSo9KbvR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%timeit -n 100 g(x).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "E3BQF1_AKeLn"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "Tmf1NT2Wqv5p"
|
|
},
|
|
"source": [
|
|
"## Parallelization over multiple accelerators with `pmap`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "t6RRAFn1CEln"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"jax.device_count()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "tEK1I6Duqunw"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax import pmap"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "S-iCNfeGqzkY"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
|
|
"print(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "xgutf5JPP3wi"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "uvDL2_bCq7kq"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"plt.plot(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "xf5N9ZRirJhL"
|
|
},
|
|
"source": [
|
|
"### Collective communication operations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import partial\n",
|
|
"from jax.lax import psum\n",
|
|
"\n",
|
|
"@partial(pmap, axis_name='i')\n",
|
|
"def f(x):\n",
|
|
" total = psum(x, 'i')\n",
|
|
" return x / total, total\n",
|
|
"\n",
|
|
"normalized, total = f(jnp.arange(8.))\n",
|
|
"\n",
|
|
"print(f\"normalized:\\n{normalized}\\n\")\n",
|
|
"print(\"total:\", total)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"colab_type": "text",
|
|
"id": "jC-KIMQ1q-lK"
|
|
},
|
|
"source": [
|
|
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Automatic parallelization with `sharded_jit` (new!)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax.experimental import sharded_jit, PartitionSpec as P"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax import lax\n",
|
|
"\n",
|
|
"conv = lambda image, kernel: lax.conv(image, kernel, (1, 1), 'SAME')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image = jnp.ones((1, 8, 2000, 1000)).astype(np.float32)\n",
|
|
"kernel = jnp.array(np.random.random((8, 8, 5, 5)).astype(np.float32))\n",
|
|
"\n",
|
|
"np.set_printoptions(edgeitems=1)\n",
|
|
"conv(image, kernel)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%timeit conv(image, kernel).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"image_partitions = P(1, 1, 4, 2)\n",
|
|
"sharded_conv = sharded_jit(conv,\n",
|
|
" in_parts=(image_partitions, None),\n",
|
|
" out_parts=image_partitions)\n",
|
|
"\n",
|
|
"sharded_conv(image, kernel)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%timeit -n 10 sharded_conv(image, kernel).block_until_ready()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "TPU",
|
|
"colab": {
|
|
"collapsed_sections": [
|
|
"AvXl1WDPKjmV"
|
|
],
|
|
"name": "JAX demo.ipynb",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|