2020-05-08 11:11:42 -07:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "view-in-github"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
2024-09-20 07:51:48 -07:00
|
|
|
"<a href=\"https://colab.research.google.com/github/jax-ml/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
2020-05-08 11:11:42 -07:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "WkadOyTDCAWD"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"# JAX Colab GPU Test\n",
|
|
|
|
"\n",
|
|
|
|
"This notebook is meant to be run in a [Colab](http://colab.research.google.com) GPU runtime as a basic check for JAX updates."
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2023-11-29 16:52:09 -08:00
|
|
|
"execution_count": 1,
|
2020-05-08 11:11:42 -07:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
"height": 68
|
2023-11-29 16:52:09 -08:00
|
|
|
},
|
|
|
|
"colab_type": "code",
|
|
|
|
"id": "_tKNrbqqBHwu",
|
|
|
|
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2023-11-29 16:52:09 -08:00
|
|
|
"name": "stdout",
|
2020-05-08 11:11:42 -07:00
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"gpu-t4-s-kbefivsjoreh\n",
|
|
|
|
"0.1.64\n",
|
|
|
|
"0.1.45\n"
|
2023-11-29 16:52:09 -08:00
|
|
|
]
|
2020-05-08 11:11:42 -07:00
|
|
|
}
|
2023-11-29 16:52:09 -08:00
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"import jax\n",
|
|
|
|
"import jaxlib\n",
|
|
|
|
"\n",
|
|
|
|
"!cat /var/colab/hostname\n",
|
|
|
|
"print(jax.__version__)\n",
|
|
|
|
"print(jaxlib.__version__)"
|
2020-05-08 11:11:42 -07:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "oqEG21rADO1F"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"## Confirm Device"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2023-11-29 16:52:09 -08:00
|
|
|
"execution_count": 2,
|
2020-05-08 11:11:42 -07:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
"height": 34
|
2023-11-29 16:52:09 -08:00
|
|
|
},
|
|
|
|
"colab_type": "code",
|
|
|
|
"id": "8BwzMYhKGQj6",
|
|
|
|
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2023-11-29 16:52:09 -08:00
|
|
|
"name": "stdout",
|
2020-05-08 11:11:42 -07:00
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"JAX device type: gpu:0\n"
|
2023-11-29 16:52:09 -08:00
|
|
|
]
|
2020-05-08 11:11:42 -07:00
|
|
|
}
|
2023-11-29 16:52:09 -08:00
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"import jax\n",
|
|
|
|
"key = jax.random.PRNGKey(1701)\n",
|
|
|
|
"arr = jax.random.normal(key, (1000,))\n",
|
|
|
|
"device = list(arr.devices())[0]\n",
|
|
|
|
"print(f\"JAX device type: {device}\")\n",
|
|
|
|
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
|
2020-05-08 11:11:42 -07:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "z0FUY9yUC4k1"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"## Matrix Multiplication"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2023-11-29 16:52:09 -08:00
|
|
|
"execution_count": 3,
|
2020-05-08 11:11:42 -07:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
"height": 34
|
2023-11-29 16:52:09 -08:00
|
|
|
},
|
|
|
|
"colab_type": "code",
|
|
|
|
"id": "eXn8GUl6CG5N",
|
|
|
|
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
2023-11-29 16:52:09 -08:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"1.0216676\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2020-05-08 11:11:42 -07:00
|
|
|
"source": [
|
|
|
|
"import jax\n",
|
|
|
|
"import numpy as np\n",
|
|
|
|
"\n",
|
|
|
|
"# matrix multiplication on GPU\n",
|
|
|
|
"key = jax.random.PRNGKey(0)\n",
|
|
|
|
"x = jax.random.normal(key, (3000, 3000))\n",
|
|
|
|
"result = jax.numpy.dot(x, x.T).mean()\n",
|
|
|
|
"print(result)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "0zTA2Q19DW4G"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"## Linear Algebra"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2023-11-29 16:52:09 -08:00
|
|
|
"execution_count": 4,
|
2020-05-08 11:11:42 -07:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
"height": 51
|
2023-11-29 16:52:09 -08:00
|
|
|
},
|
|
|
|
"colab_type": "code",
|
|
|
|
"id": "uW9j84_UDYof",
|
|
|
|
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
2023-11-29 16:52:09 -08:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
|
|
|
|
" 2.866489 1.8229384 1.5478926]\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2020-05-08 11:11:42 -07:00
|
|
|
"source": [
|
|
|
|
"import jax.numpy as jnp\n",
|
|
|
|
"import jax.random as rand\n",
|
|
|
|
"\n",
|
|
|
|
"N = 10\n",
|
|
|
|
"M = 20\n",
|
|
|
|
"key = rand.PRNGKey(1701)\n",
|
|
|
|
"\n",
|
|
|
|
"X = rand.normal(key, (N, M))\n",
|
|
|
|
"u, s, vt = jnp.linalg.svd(X)\n",
|
|
|
|
"assert u.shape == (N, N)\n",
|
|
|
|
"assert vt.shape == (M, M)\n",
|
|
|
|
"print(s)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2023-11-29 16:52:09 -08:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "jCyKUn4-DCXn"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"## XLA Compilation"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2023-11-29 16:52:09 -08:00
|
|
|
"execution_count": 5,
|
2020-05-08 11:11:42 -07:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
|
|
|
"height": 51
|
2023-11-29 16:52:09 -08:00
|
|
|
},
|
|
|
|
"colab_type": "code",
|
|
|
|
"id": "2GOn_HhDPuEn",
|
|
|
|
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
|
2020-05-08 11:11:42 -07:00
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2023-11-29 16:52:09 -08:00
|
|
|
"name": "stdout",
|
2020-05-08 11:11:42 -07:00
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
|
|
|
|
" 0.13093245]\n"
|
2023-11-29 16:52:09 -08:00
|
|
|
]
|
2020-05-08 11:11:42 -07:00
|
|
|
}
|
2023-11-29 16:52:09 -08:00
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"@jax.jit\n",
|
|
|
|
"def selu(x, alpha=1.67, lmbda=1.05):\n",
|
|
|
|
" return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
|
|
|
|
"x = jax.random.normal(key, (5000,))\n",
|
|
|
|
"result = selu(x).block_until_ready()\n",
|
|
|
|
"print(result)"
|
2020-05-08 11:11:42 -07:00
|
|
|
]
|
|
|
|
}
|
2023-11-29 16:52:09 -08:00
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"accelerator": "GPU",
|
|
|
|
"colab": {
|
|
|
|
"collapsed_sections": [],
|
|
|
|
"name": "JAX Colab GPU Test",
|
|
|
|
"provenance": []
|
|
|
|
},
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"name": "python3"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 0
|
2021-06-18 08:55:08 +03:00
|
|
|
}
|