[PJRT C API] Add a util method to get the PJRT C API version of the backend.

Skip some memories tests that are not supported on plugins with a PJRT C API version older than 0.32.

PiperOrigin-RevId: 581008059
This commit is contained in:
Jieying Luo 2023-11-09 13:28:52 -08:00 committed by jax authors
parent a23aac5566
commit 3f1900e2e3
3 changed files with 51 additions and 5 deletions

View File

@ -331,6 +331,13 @@ def is_cloud_tpu():
return 'libtpu' in xla_bridge.get_backend().platform_version
def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
  """Reports whether the backend's PJRT C API version meets a minimum.

  Args:
    major_version: minimum required major version.
    minor_version: minimum required minor version.

  Returns:
    True when `xla_bridge.backend_pjrt_c_api_version()` reports no version
    (None); otherwise the result of comparing the reported version against
    `(major_version, minor_version)` with `>=`.
  """
  required = (major_version, minor_version)
  reported = xla_bridge.backend_pjrt_c_api_version()
  # No reported version is treated as satisfying any requirement.
  return True if reported is None else reported >= required
def is_device_tpu_v4():
  """Returns whether the first device from `jax.devices()` has kind "TPU v4"."""
  first_device = jax.devices()[0]
  return first_device.device_kind == "TPU v4"

View File

@ -21,29 +21,29 @@ XLA. There are also a handful of related casting utilities.
from collections.abc import Mapping
import dataclasses
from functools import partial, lru_cache
from functools import lru_cache, partial
import glob
import importlib
import json
import logging
import os
import pathlib
import platform as py_platform
import pkgutil
import platform as py_platform
import sys
import threading
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Tuple, Union
import warnings
from jax._src import config
from jax._src import distributed
from jax._src import traceback_util
from jax._src import util
from jax._src.cloud_tpu_init import maybe_import_libtpu
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
from jax._src import util
logger = logging.getLogger(__name__)
@ -856,6 +856,20 @@ def default_backend() -> str:
"""Returns the platform name of the default XLA backend."""
return get_backend(None).platform
def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]:
  """Looks up the PJRT C API version exposed by a backend.

  Args:
    platform: forwarded unchanged to `get_backend` to select the backend.

  Returns:
    A `(major, minor)` tuple read from the backend's
    `pjrt_c_api_major_version` and `pjrt_c_api_minor_version` attributes, or
    None if the backend lacks either attribute (i.e. it does not use the
    PJRT C API).
  """
  client = get_backend(platform)
  version_attrs = ("pjrt_c_api_major_version", "pjrt_c_api_minor_version")
  # Both attributes must be present before either one is read.
  if not all(hasattr(client, attr) for attr in version_attrs):
    return None
  return (client.pjrt_c_api_major_version, client.pjrt_c_api_minor_version)
@lru_cache
def local_devices(process_index: Optional[int] = None,
backend: Optional[Union[str, xla_client.Client]] = None,

View File

@ -666,6 +666,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
self.assertEqual(cache_info2.misses, cache_info1.misses)
def test_device_put_host_to_hbm(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
np_inp = jnp.arange(16).reshape(8, 2)
@ -683,6 +687,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
out_on_hbm, np_inp, s_hbm, "tpu_hbm")
def test_device_put_hbm_to_host(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
inp = jnp.arange(16).reshape(8, 2)
@ -699,6 +707,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
def test_device_put_different_device_and_memory_host_to_hbm(self):
if jax.device_count() < 3:
raise unittest.SkipTest("Test requires >=3 devices")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
out_host0 = jax.device_put(
jnp.arange(8),
@ -716,6 +727,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
def test_device_put_different_device_and_memory_hbm_to_host(self):
if jax.device_count() < 3:
raise unittest.SkipTest("Test requires >=3 devices")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
out_hbm0 = jnp.arange(8)
@ -735,6 +749,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
raise unittest.SkipTest("Test requires xla_extension_version >= 199")
if len(jax.devices()) < 2:
raise unittest.SkipTest("Test requires >=2 devices.")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
np_inp = np.arange(16).reshape(8, 2)
@ -753,6 +770,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host")
def test_device_put_resharding(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
s_hbm = s_host.with_memory_kind("tpu_hbm")
@ -777,6 +798,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
out_sharded_hbm, np_inp, s_hbm, "tpu_hbm")
def test_jit_host_inputs_via_device_put_outside(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
s_hbm = s_host.with_memory_kind("tpu_hbm")