Delete remote TPU support.

TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
Authored by Peter Hawkins on 2023-03-24 12:32:53 -07:00; committed by jax authors.
parent fad4e6f95a
commit 6ed66ada0f
13 changed files with 26 additions and 418 deletions
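For context: with the remote TPU (TPU driver / Colab TPU node) path gone, the supported workflow is to run JAX directly on a Cloud TPU VM, where the TPU devices are local and no setup call is needed. A minimal, illustrative check, not part of this commit; the install command follows the JAX 0.4.x install docs and may change:

```python
# Run on a Cloud TPU VM. Install roughly as documented for jax 0.4.x:
#   pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
import jax

devices = jax.devices()  # on a TPU VM this lists the local TPU cores directly
print(devices)
assert devices[0].platform == "tpu", "expected to be running on a TPU VM"
```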

View File

@@ -23,18 +23,6 @@ licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])
bool_flag(
name = "enable_remote_tpu",
build_setting_default = False,
)
config_setting(
name = "remote_tpu_enabled",
flag_values = {
":enable_remote_tpu": "True",
},
)
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
@@ -47,10 +35,7 @@ py_binary(
"@xla//xla/python:xla_client",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + select({
":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"],
"//conditions:default": [],
}) + if_cuda([
]) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([

View File

@@ -219,8 +219,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_remote_tpu, enable_rocm,
enable_plugin_device):
enable_tpu, enable_rocm, enable_plugin_device):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@@ -286,8 +285,6 @@ def write_bazelrc(*, python_bin_path, remote_build,
f.write("build --config=nonccl\n")
if enable_tpu:
f.write("build --config=tpu\n")
if enable_remote_tpu:
f.write("build --//build:enable_remote_tpu=true\n")
if enable_rocm:
f.write("build --config=rocm\n")
if not enable_nccl:
@@ -375,10 +372,6 @@ def main():
parser,
"enable_tpu",
help_str="Should we build with Cloud TPU VM support enabled?")
add_boolean_argument(
parser,
"enable_remote_tpu",
help_str="Should we build with remote Cloud TPU support enabled?")
add_boolean_argument(
parser,
"enable_rocm",
@@ -514,7 +507,6 @@ def main():
print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))
print("TPU enabled: {}".format("yes" if args.enable_tpu else "no"))
print("Remote TPU enabled: {}".format("yes" if args.enable_remote_tpu else "no"))
print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
if args.enable_rocm:
@@ -542,7 +534,6 @@ def main():
enable_cuda=args.enable_cuda,
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_remote_tpu=args.enable_remote_tpu,
enable_rocm=args.enable_rocm,
enable_plugin_device=args.enable_plugin_device,
)
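With `enable_remote_tpu` removed, TPU support in the build script is governed solely by `--enable_tpu`. A minimal sketch of the surviving flag-to-bazelrc plumbing, using only what the hunks above show; the helper name `write_tpu_flags` is hypothetical:

```python
# Illustrative distillation of the kept lines in write_bazelrc; the real
# function writes many other options to ../.jax_configure.bazelrc as well.
def write_tpu_flags(bazelrc_path: str, enable_tpu: bool) -> None:
    with open(bazelrc_path, "a") as f:
        if enable_tpu:
            f.write("build --config=tpu\n")
        # There is no longer a --//build:enable_remote_tpu flag to emit.

write_tpu_flags(".jax_configure.bazelrc", enable_tpu=True)
```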

View File

@@ -117,19 +117,6 @@ def patch_copy_xla_extension_stubs(dst_dir):
f.write(src)
def patch_copy_tpu_client_py(dst_dir):
with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
src = src.replace("from xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from xla.python import xla_client",
"from . import xla_client")
src = src.replace(
"from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from . import tpu_client_extension as _tpu_client")
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
f.write(src)
def verify_mac_libraries_dont_reference_chkstack():
"""Verifies that xla_extension.so doesn't depend on ____chkstk_darwin.
@@ -250,10 +237,6 @@ def prepare_wheel(sources_path):
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
patch_copy_xla_extension_stubs(jaxlib_dir)
if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"):
copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so")
patch_copy_tpu_client_py(jaxlib_dir)
def edit_jaxlib_version(sources_path):
version_regex = re.compile(r'__version__ = \"(.*)\"')

View File

@@ -1,29 +1,5 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hLEyhfMqmnrt"
},
"source": [
"## Colab JAX TPU Setup"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5CTEVmyKmkfp"
},
"outputs": [],
"source": [
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()"
]
},
{
"cell_type": "markdown",
"metadata": {

View File

@@ -25,30 +25,6 @@
"Alex Alemi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j-n2r719AKee",
"colab_type": "text"
},
"source": [
"# Cloud TPU Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ReFcuyaKAxh4",
"colab_type": "code",
"colab": {}
},
"source": [
"from jax.tools import colab_tpu\n",
"colab_tpu.setup_tpu()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
@@ -76,10 +52,6 @@
"from jax import vmap, jit, grad, ops, lax, config\n",
"from jax import random as jr\n",
"\n",
"# The following is required to use TPU Driver as JAX's backend.\n",
"config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
"config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",

View File

@@ -14,30 +14,6 @@
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "LpPtl0n4rg6L",
"colab_type": "text"
},
"source": [
"# Colab JAX TPU Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4DYY4Yyhq8vG",
"colab_type": "code",
"colab": {}
},
"source": [
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {

View File

@@ -42,10 +42,7 @@
"outputs": [],
"source": [
"# Grab other packages for this demo.\n",
"!pip install -U -q Pillow moviepy proglog scikit-image\n",
"\n",
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()"
"!pip install -U -q Pillow moviepy proglog scikit-image"
]
},
{

View File

@@ -25,23 +25,9 @@
"id": "7mCgBzix2fd3"
},
"source": [
"## Colab TPU Setup\n",
"## TPU Setup\n",
"\n",
"If you're running this code in Google Colab, be sure to choose *Runtime*→*Change Runtime Type* and choose **TPU** from the Hardware Accelerator menu.\n",
"\n",
"Once this is done, you can run the following to set up the Colab TPU for use with JAX:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "hn7HtC2QS92b"
},
"outputs": [],
"source": [
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()"
"This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs."
]
},
{

View File

@@ -27,18 +27,9 @@ Conceptually, this is not very different from vectorisation, where the same oper
+++ {"id": "7mCgBzix2fd3"}
## Colab TPU Setup
## TPU Setup
If you're running this code in Google Colab, be sure to choose *Runtime*→*Change Runtime Type* and choose **TPU** from the Hardware Accelerator menu.
Once this is done, you can run the following to set up the Colab TPU for use with JAX:
```{code-cell} ipython3
:id: hn7HtC2QS92b
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs.
+++ {"id": "gN6VbcdRTcdE"}

View File

@@ -114,12 +114,6 @@ import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except:
tpu_driver_client = None # type: ignore
# TODO(rocm): check if we need the same for rocm.
cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")

View File

@@ -34,7 +34,6 @@ import numpy as np
from jax._src import lib
from jax._src import distributed
from jax._src.config import flags, bool_env, config, int_env
from jax._src.lib import tpu_driver_client
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src import util
@@ -164,16 +163,6 @@ def get_compile_options(
# Backends
def _make_tpu_driver_client() -> Optional[xla_client.Client]:
if tpu_driver_client is None:
logger.info("Remote TPU is not linked into jax; skipping remote TPU.")
return None
if FLAGS.jax_backend_target is None:
logger.info("No --jax_backend_target was provided; skipping remote TPU.")
return None
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
def _log_warning():
warnings.warn(
@@ -218,8 +207,6 @@ register_backend_factory('interpreter', xla_client.make_interpreter_client,
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
def make_gpu_client(
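With the `tpu_driver` factory deleted, backends come only from the remaining `register_backend_factory` calls shown above, each registered with a name, a factory, and a priority; the priorities in the diff suggest higher values are preferred when choosing a default. A self-contained, plain-Python stand-in for that pattern, not JAX's actual implementation:

```python
from typing import Callable, Dict, Optional, Tuple

# Toy registry mirroring the (name, factory, priority) pattern above.
_factories: Dict[str, Tuple[Callable[[], Optional[str]], int]] = {}

def register_backend_factory(name: str, factory, priority: int = 0) -> None:
    _factories[name] = (factory, priority)

def choose_default_backend() -> Optional[str]:
    # Try factories from highest to lowest priority; a factory may return None
    # or raise when its platform is unavailable.
    for _, (factory, _prio) in sorted(
            _factories.items(), key=lambda kv: kv[1][1], reverse=True):
        try:
            client = factory()
        except Exception:
            continue
        if client is not None:
            return client
    return None

register_backend_factory("cpu", lambda: "cpu-client", priority=0)
print(choose_default_backend())  # -> "cpu-client"
```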

View File

@@ -14,34 +14,26 @@
"""Utilities for running JAX on Cloud TPUs via Colab."""
import requests
import os
import textwrap
from jax.config import config
message = """
As of JAX 0.4.0, JAX only supports TPU VMs, not the older Colab TPUs.
TPU_DRIVER_MODE = 0
We recommend trying Kaggle Notebooks
(https://www.kaggle.com/code, click on "New Notebook" near the top) which offer
TPU VMs. You have to create an account, log in, and verify your account to get
accelerator support.
Once you do that, there's a new "TPU 1VM v3-8" accelerator option. This gives
you a TPU notebook environment similar to Colab, but using the newer TPU VM
architecture. This should be a less buggy, more performant, and overall better
experience than the older TPU node architecture.
It is also possible to use Colab together with a self-hosted Jupyter kernel
running on a Cloud TPU VM. See
https://research.google.com/colaboratory/local-runtimes.html
for details.
"""
def setup_tpu(tpu_driver_version='tpu_driver_20230216'):
"""Sets up Colab to run on TPU.
Note: make sure the Colab Runtime is set to Accelerator: TPU.
Args
----
tpu_driver_version : (str) specify the version identifier for the tpu driver.
Set to "tpu_driver_nightly" to use the nightly tpu driver build.
"""
global TPU_DRIVER_MODE
if not TPU_DRIVER_MODE:
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/{tpu_driver_version}'
requests.post(url)
TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# TODO(skyewm): Remove this after SPMD is supported for colab tpu.
config.update('jax_array', False)
def setup_tpu(tpu_driver_version=None):
"""Returns an error. Do not use."""
raise RuntimeError(textwrap.dedent(message))
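For reference, this is what callers of the removed Colab setup path see after this commit; a usage sketch assuming a jax build that includes the change:

```python
import jax.tools.colab_tpu

try:
    jax.tools.colab_tpu.setup_tpu()
except RuntimeError as err:
    # The error text is the migration message defined above, pointing users at
    # Kaggle TPU VMs and self-hosted TPU VM runtimes for Colab.
    print(err)
```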

View File

@@ -1,222 +0,0 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Colab TPU Test",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_tpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WkadOyTDCAWD",
"colab_type": "text"
},
"source": [
"# JAX Colab TPU Test\n",
"\n",
"This notebook is meant to be run in a [Colab](http://colab.research.google.com) TPU runtime as a basic check for JAX updates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_tKNrbqqBHwu",
"colab_type": "code",
"outputId": "bf0043b0-6f2b-44e4-9822-4f426b3d158e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"import jax\n",
"import jaxlib\n",
"\n",
"!cat /var/colab/hostname\n",
"print(jax.__version__)\n",
"print(jaxlib.__version__)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"tpu-s-2dna7uebo6z96\n",
"0.1.64\n",
"0.1.45\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzVStuLobcoG",
"colab_type": "text"
},
"source": [
"## TPU Setup"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "IXF0_gNCRH08",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import jax.tools.colab_tpu\n",
"jax.tools.colab_tpu.setup_tpu()"
],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {
"id": "oqEG21rADO1F",
"colab_type": "text"
},
"source": [
"## Confirm Device"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "d51b7f21-d300-4420-8c5c-483bace8617d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"from jaxlib import tpu_client_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert isinstance(device, tpu_client_extension.TpuDevice), \"unexpected JAX device type\""
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"JAX device type: TPU_0(host=0,(0,0,0,0))\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0FUY9yUC4k1",
"colab_type": "text"
},
"source": [
"## Matrix Multiplication"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "9954a064-ef8b-4db3-aad7-85d07b50f678",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import jax\n",
"import numpy as np\n",
"\n",
"# matrix multiplication on GPU\n",
"key = jax.random.PRNGKey(0)\n",
"x = jax.random.normal(key, (3000, 3000))\n",
"result = jax.numpy.dot(x, x.T).mean()\n",
"print(result)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"1.021576\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCyKUn4-DCXn",
"colab_type": "text"
},
"source": [
"## XLA Compilation"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "a4384c55-41fb-44be-845d-17b86b152068",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n",
" return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
"x = jax.random.normal(key, (5000,))\n",
"result = selu(x).block_until_ready()\n",
"print(result)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0.34676817 -0.7532211 1.7060809 ... 2.120809 -0.42622015\n",
" 0.13093244]\n"
],
"name": "stdout"
}
]
}
]
}