rocm_jax/tests/notebooks/colab_gpu.ipynb

243 lines
5.9 KiB
Plaintext
Raw Permalink Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/jax-ml/jax/blob/main/tests/notebooks/colab_gpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WkadOyTDCAWD"
},
"source": [
"# JAX Colab GPU Test\n",
"\n",
"This notebook is meant to be run in a [Colab](https://colab.research.google.com) GPU runtime as a basic sanity check that JAX still works on GPU after an update."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "_tKNrbqqBHwu",
"outputId": "ae4a051a-91ed-4742-c8e1-31de8304ef33"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpu-t4-s-kbefivsjoreh\n",
"0.1.64\n",
"0.1.45\n"
]
}
],
"source": [
"# Print the contents of /var/colab/hostname, then the jax and jaxlib\n",
"# versions, one per line.\n",
"import jax\n",
"import jaxlib\n",
"\n",
"!cat /var/colab/hostname\n",
"print(jax.__version__, jaxlib.__version__, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oqEG21rADO1F"
},
"source": [
"## Confirm Device"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "ff4f52b3-f7bb-468a-c1ad-debe65841f3f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX device type: gpu:0\n"
]
}
],
"source": [
"import jax\n",
"\n",
"# Create a small random array and look up a device that holds it.\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, shape=(1000,))\n",
"device = next(iter(arr.devices()))\n",
"print(f\"JAX device type: {device}\")\n",
"\n",
"# The rest of the notebook is a GPU check, so stop here on any other platform.\n",
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "z0FUY9yUC4k1"
},
"source": [
"## Matrix Multiplication"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "688c37f3-e830-4ba8-b1e6-b4e014cb11a9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0216676\n"
]
}
],
"source": [
"import jax\n",
"import numpy as np\n",
"\n",
"# Matrix multiplication on GPU: multiply a random 3000x3000 matrix by its\n",
"# transpose and print the mean of the product.\n",
"key = jax.random.PRNGKey(0)\n",
"x = jax.random.normal(key, shape=(3000, 3000))\n",
"result = jax.numpy.mean(jax.numpy.dot(x, x.transpose()))\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0zTA2Q19DW4G"
},
"source": [
"## Linear Algebra"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "uW9j84_UDYof",
"outputId": "80069760-12ab-4df2-9f5c-be2536de59b7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252\n",
" 2.866489 1.8229384 1.5478926]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as rand\n",
"\n",
"# Decompose a random N x M matrix and check that both returned factor\n",
"# matrices are square (N x N and M x M) before printing the singular values.\n",
"N, M = 10, 20\n",
"key = rand.PRNGKey(1701)\n",
"X = rand.normal(key, shape=(N, M))\n",
"\n",
"u, s, vt = jnp.linalg.svd(X)\n",
"assert (u.shape, vt.shape) == ((N, N), (M, M))\n",
"print(s)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jCyKUn4-DCXn"
},
"source": [
"## XLA Compilation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a51d7d07-8513-4503-bceb-d5b0e2b4e4a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925\n",
" 0.13093245]\n"
]
}
],
"source": [
"import jax\n",
"\n",
"\n",
"@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n",
"  \"\"\"Scaled exponential linear unit, compiled through jax.jit.\n",
"\n",
"  Returns lmbda * x where x > 0, and lmbda * (alpha * exp(x) - alpha)\n",
"  everywhere else.\n",
"  \"\"\"\n",
"  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
"\n",
"\n",
"# Build the key here instead of relying on the one the Linear Algebra cell\n",
"# leaves in the kernel, so this cell also runs on its own. The seed is the\n",
"# same as in that cell, so the generated values do not change.\n",
"key = jax.random.PRNGKey(1701)\n",
"x = jax.random.normal(key, (5000,))\n",
"\n",
"# Wait for the computation to finish before printing the result.\n",
"result = selu(x).block_until_ready()\n",
"print(result)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "JAX Colab GPU Test",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}