{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WkadOyTDCAWD" }, "source": [ "# JAX Colab CPU Test\n", "\n", "This notebook is meant to be run in a [Colab](http://colab.research.google.com) CPU runtime as a basic check for JAX updates." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "_tKNrbqqBHwu", "outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "m-s-1p12yf76kgzz\n", "0.1.64\n", "0.1.45\n" ] } ], "source": [ "import jax\n", "import jaxlib\n", "\n", "!cat /var/colab/hostname\n", "print(jax.__version__)\n", "print(jaxlib.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oqEG21rADO1F" }, "source": [ "## Confirm Device" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "8BwzMYhKGQj6", "outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n", " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "JAX device type: cpu:0\n" ] } ], "source": [ "import jax\n", "\n", "# Create a small array and check which device JAX placed it on;\n", "# this notebook targets a CPU-only runtime, so it must be the CPU platform.\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", "device = list(arr.devices())[0]\n", "print(f\"JAX device type: {device}\")\n", "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "z0FUY9yUC4k1" }, "source": [ "## Matrix 
Multiplication" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "eXn8GUl6CG5N", "outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0216691\n" ] } ], "source": [ "import jax\n", "\n", "# matrix multiplication on CPU\n", "key = jax.random.PRNGKey(0)\n", "x = jax.random.normal(key, (3000, 3000))\n", "result = jax.numpy.dot(x, x.T).mean()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0zTA2Q19DW4G" }, "source": [ "## Linear Algebra" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "uW9j84_UDYof", "outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n", " 2.8664916 1.8229378 1.5478933]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "import jax.random as rand\n", "\n", "N = 10\n", "M = 20\n", "key = rand.PRNGKey(1701)\n", "\n", "X = rand.normal(key, (N, M))\n", "# full_matrices=True (the default, spelled out) is why u is (N, N) and vt is (M, M).\n", "u, s, vt = jnp.linalg.svd(X, full_matrices=True)\n", "assert u.shape == (N, N)\n", "assert vt.shape == (M, M)\n", "print(s)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jCyKUn4-DCXn" }, "source": [ "## XLA Compilation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "2GOn_HhDPuEn", "outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0.34676832 -0.7532232 1.7060695 ... 
2.1208048 -0.42621925\n", " 0.13093236]\n" ] } ], "source": [ "@jax.jit\n", "def selu(x, alpha=1.67, lmbda=1.05):\n", "  \"\"\"SELU activation: lmbda * x where x > 0, else lmbda * (alpha * exp(x) - alpha).\"\"\"\n", "  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", "\n", "# Define the key here so this cell does not depend on state left by earlier cells.\n", "key = jax.random.PRNGKey(1701)\n", "x = jax.random.normal(key, (5000,))\n", "result = selu(x).block_until_ready()\n", "print(result)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "JAX Colab CPU Test", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }