mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
250 lines
6.2 KiB
Plaintext
250 lines
6.2 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "JAX Colab CPU Test",
|
|
"provenance": [],
|
|
"collapsed_sections": []
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/jax-ml/jax/blob/main/tests/notebooks/colab_cpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WkadOyTDCAWD",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# JAX Colab CPU Test\n",
|
|
"\n",
|
|
"This notebook is meant to be run in a [Colab](https://colab.research.google.com) CPU runtime as a basic check for JAX updates."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "_tKNrbqqBHwu",
|
|
"colab_type": "code",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
},
|
|
"outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b"
|
|
},
|
|
"source": [
|
|
"import jax\n",
|
|
"import jaxlib\n",
|
|
"\n",
|
|
"!cat /var/colab/hostname\n",
|
|
"print(jax.__version__)\n",
|
|
"print(jaxlib.__version__)"
|
|
],
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"m-s-1p12yf76kgzz\n",
|
|
"0.1.64\n",
|
|
"0.1.45\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "oqEG21rADO1F",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Confirm Device"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "8BwzMYhKGQj6",
|
|
"outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 68
|
|
}
|
|
},
|
|
"execution_count": 2,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n",
|
|
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
|
|
],
|
|
"name": "stderr"
|
|
},
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"JAX device type: cpu:0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from jaxlib import xla_extension\n",
|
|
"import jax\n",
|
|
"key = jax.random.PRNGKey(1701)\n",
|
|
"arr = jax.random.normal(key, (1000,))\n",
|
|
"device = list(arr.devices())[0]\n",
|
|
"print(f\"JAX device type: {device}\")\n",
|
|
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "z0FUY9yUC4k1",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Matrix Multiplication"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "eXn8GUl6CG5N",
|
|
"outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 34
|
|
}
|
|
},
|
|
"source": [
|
|
"import jax\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# matrix multiplication on CPU\n",
|
|
"key = jax.random.PRNGKey(0)\n",
|
|
"x = jax.random.normal(key, (3000, 3000))\n",
|
|
"result = jax.numpy.dot(x, x.T).mean()\n",
|
|
"print(result)"
|
|
],
|
|
"execution_count": 3,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.0216691\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "0zTA2Q19DW4G",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## Linear Algebra"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "uW9j84_UDYof",
|
|
"colab_type": "code",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 51
|
|
},
|
|
"outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7"
|
|
},
|
|
"source": [
|
|
"import jax.numpy as jnp\n",
|
|
"import jax.random as rand\n",
|
|
"\n",
|
|
"N = 10\n",
|
|
"M = 20\n",
|
|
"key = rand.PRNGKey(1701)\n",
|
|
"\n",
|
|
"X = rand.normal(key, (N, M))\n",
|
|
"u, s, vt = jnp.linalg.svd(X)\n",
|
|
"assert u.shape == (N, N)\n",
|
|
"assert vt.shape == (M, M)\n",
|
|
"print(s)"
|
|
],
|
|
"execution_count": 4,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n",
|
|
" 2.8664916 1.8229378 1.5478933]\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "jCyKUn4-DCXn",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## XLA Compilation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab_type": "code",
|
|
"id": "2GOn_HhDPuEn",
|
|
"outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 51
|
|
}
|
|
},
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def selu(x, alpha=1.67, lmbda=1.05):\n",
|
|
" return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
|
|
"x = jax.random.normal(key, (5000,))\n",
|
|
"result = selu(x).block_until_ready()\n",
|
|
"print(result)"
|
|
],
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n",
|
|
" 0.13093236]\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
}
|
|
]
|
|
}
|