{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WkadOyTDCAWD" }, "source": [ "# JAX Colab GPU Test\n", "\n", "This notebook is meant to be run in a [Colab](http://colab.research.google.com) GPU runtime as a basic check for JAX updates." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "_tKNrbqqBHwu", "outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gpu-t4-s-kbefivsjoreh\n", "0.1.64\n", "0.1.45\n" ] } ], "source": [ "import jax\n", "import jaxlib\n", "\n", "!cat /var/colab/hostname\n", "print(jax.__version__)\n", "print(jaxlib.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oqEG21rADO1F" }, "source": [ "## Confirm Device" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "8BwzMYhKGQj6", "outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX device type: gpu:0\n" ] } ], "source": [ "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", "device = list(arr.devices())[0]\n", "print(f\"JAX device type: {device}\")\n", "assert device.platform == \"gpu\", \"unexpected JAX device type\"" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "z0FUY9yUC4k1" }, "source": [ "## Matrix Multiplication" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "eXn8GUl6CG5N", "outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "1.0216676\n" ] } ], "source": [ "import jax\n", "import numpy as np\n", "\n", "# matrix multiplication on GPU\n", "key = jax.random.PRNGKey(0)\n", "x = jax.random.normal(key, (3000, 3000))\n", "result = jax.numpy.dot(x, x.T).mean()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0zTA2Q19DW4G" }, "source": [ "## Linear Algebra" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "uW9j84_UDYof", "outputId": "80069760-12ab-4df2-9f5c-be2536de59b7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n", " 2.866489 1.8229384 1.5478926]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "import jax.random as rand\n", "\n", "N = 10\n", "M = 20\n", "key = rand.PRNGKey(1701)\n", "\n", "X = rand.normal(key, (N, M))\n", "u, s, vt = jnp.linalg.svd(X)\n", "assert u.shape == (N, N)\n", "assert vt.shape == (M, M)\n", "print(s)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jCyKUn4-DCXn" }, "source": [ "## XLA Compilation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "2GOn_HhDPuEn", "outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0.34676838 -0.7532232 1.7060698 ... 
2.1208055 -0.42621925\n", " 0.13093245]\n" ] } ], "source": [ "@jax.jit\n", "def selu(x, alpha=1.67, lmbda=1.05):\n", " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", "x = jax.random.normal(key, (5000,))\n", "result = selu(x).block_until_ready()\n", "print(result)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "JAX Colab GPU Test", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }