From 3f1900e2e3a324ec5e9d3ce4c8bb030b951ad857 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 9 Nov 2023 13:28:52 -0800 Subject: [PATCH] [PJRT C API] Add a util method to get the PJRT C API version of the backend. Disable some memories tests which are not supported on plugin older than 0.32. PiperOrigin-RevId: 581008059 --- jax/_src/test_util.py | 7 +++++++ jax/_src/xla_bridge.py | 24 +++++++++++++++++++----- tests/memories_test.py | 25 +++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 1eaa4466e..f86aa9ee2 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -331,6 +331,13 @@ def is_cloud_tpu(): return 'libtpu' in xla_bridge.get_backend().platform_version +def pjrt_c_api_version_at_least(major_version: int, minor_version: int): + pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version() + if pjrt_c_api_versions is None: + return True + return pjrt_c_api_versions >= (major_version, minor_version) + + def is_device_tpu_v4(): return jax.devices()[0].device_kind == "TPU v4" diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 08e499507..0253cd067 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -21,29 +21,29 @@ XLA. There are also a handful of related casting utilities. from collections.abc import Mapping import dataclasses -from functools import partial, lru_cache +from functools import lru_cache, partial import glob import importlib import json import logging import os import pathlib -import platform as py_platform import pkgutil +import platform as py_platform import sys import threading -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Tuple, Union import warnings from jax._src import config from jax._src import distributed +from jax._src import traceback_util +from jax._src import util from jax._src.cloud_tpu_init import maybe_import_libtpu from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version -from jax._src import traceback_util -from jax._src import util logger = logging.getLogger(__name__) @@ -856,6 +856,20 @@ def default_backend() -> str: """Returns the platform name of the default XLA backend.""" return get_backend(None).platform + +def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]: + """Returns the PJRT C API version of the backend. + + Returns None if the backend does not use PJRT C API. + """ + backend = get_backend(platform) + if hasattr(backend, "pjrt_c_api_major_version") and hasattr( + backend, "pjrt_c_api_minor_version" + ): + return (backend.pjrt_c_api_major_version, backend.pjrt_c_api_minor_version) + return None + + @lru_cache def local_devices(process_index: Optional[int] = None, backend: Optional[Union[str, xla_client.Client]] = None, diff --git a/tests/memories_test.py b/tests/memories_test.py index 6a1f9b458..5cac75ffb 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -666,6 +666,10 @@ class MemoriesTest(jtu.BufferDonationTestCase): self.assertEqual(cache_info2.misses, cache_info1.misses) def test_device_put_host_to_hbm(self): + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") + mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") np_inp = jnp.arange(16).reshape(8, 2) @@ -683,6 +687,10 @@ class MemoriesTest(jtu.BufferDonationTestCase): out_on_hbm, np_inp, s_hbm, "tpu_hbm") def test_device_put_hbm_to_host(self): + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") + mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") inp = jnp.arange(16).reshape(8, 2) @@ -699,6 +707,9 @@ class MemoriesTest(jtu.BufferDonationTestCase): def test_device_put_different_device_and_memory_host_to_hbm(self): if jax.device_count() < 3: raise unittest.SkipTest("Test requires >=3 devices") + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") out_host0 = jax.device_put( jnp.arange(8), @@ -716,6 +727,9 @@ class MemoriesTest(jtu.BufferDonationTestCase): def test_device_put_different_device_and_memory_hbm_to_host(self): if jax.device_count() < 3: raise unittest.SkipTest("Test requires >=3 devices") + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") out_hbm0 = jnp.arange(8) @@ -735,6 +749,9 @@ class MemoriesTest(jtu.BufferDonationTestCase): raise unittest.SkipTest("Test requires xla_extension_version >= 199") if len(jax.devices()) < 2: raise unittest.SkipTest("Test requires >=2 devices.") + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") np_inp = np.arange(16).reshape(8, 2) @@ -753,6 +770,10 @@ class MemoriesTest(jtu.BufferDonationTestCase): out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host") def test_device_put_resharding(self): + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") s_hbm = s_host.with_memory_kind("tpu_hbm") @@ -777,6 +798,10 @@ class MemoriesTest(jtu.BufferDonationTestCase): out_sharded_hbm, np_inp, s_hbm, "tpu_hbm") def test_jit_host_inputs_via_device_put_outside(self): + # TODO(jieying): remove after 12/26/2023. + if not jtu.pjrt_c_api_version_at_least(0, 32): + raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") + mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") s_hbm = s_host.with_memory_kind("tpu_hbm")