mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Add a util method to get the PJRT C API version of the backend.
Disable some memories tests which are not supported on plugin older than 0.32. PiperOrigin-RevId: 581008059
This commit is contained in:
parent
a23aac5566
commit
3f1900e2e3
@ -331,6 +331,13 @@ def is_cloud_tpu():
|
||||
return 'libtpu' in xla_bridge.get_backend().platform_version
|
||||
|
||||
|
||||
def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
|
||||
pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version()
|
||||
if pjrt_c_api_versions is None:
|
||||
return True
|
||||
return pjrt_c_api_versions >= (major_version, minor_version)
|
||||
|
||||
|
||||
def is_device_tpu_v4():
|
||||
return jax.devices()[0].device_kind == "TPU v4"
|
||||
|
||||
|
@ -21,29 +21,29 @@ XLA. There are also a handful of related casting utilities.
|
||||
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache
|
||||
from functools import lru_cache, partial
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform as py_platform
|
||||
import pkgutil
|
||||
import platform as py_platform
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import distributed
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.cloud_tpu_init import maybe_import_libtpu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -856,6 +856,20 @@ def default_backend() -> str:
|
||||
"""Returns the platform name of the default XLA backend."""
|
||||
return get_backend(None).platform
|
||||
|
||||
|
||||
def backend_pjrt_c_api_version(platform=None) -> Optional[Tuple[int, int]]:
|
||||
"""Returns the PJRT C API version of the backend.
|
||||
|
||||
Returns None if the backend does not use PJRT C API.
|
||||
"""
|
||||
backend = get_backend(platform)
|
||||
if hasattr(backend, "pjrt_c_api_major_version") and hasattr(
|
||||
backend, "pjrt_c_api_minor_version"
|
||||
):
|
||||
return (backend.pjrt_c_api_major_version, backend.pjrt_c_api_minor_version)
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache
|
||||
def local_devices(process_index: Optional[int] = None,
|
||||
backend: Optional[Union[str, xla_client.Client]] = None,
|
||||
|
@ -666,6 +666,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
|
||||
def test_device_put_host_to_hbm(self):
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
|
||||
np_inp = jnp.arange(16).reshape(8, 2)
|
||||
@ -683,6 +687,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
out_on_hbm, np_inp, s_hbm, "tpu_hbm")
|
||||
|
||||
def test_device_put_hbm_to_host(self):
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
|
||||
inp = jnp.arange(16).reshape(8, 2)
|
||||
@ -699,6 +707,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
def test_device_put_different_device_and_memory_host_to_hbm(self):
|
||||
if jax.device_count() < 3:
|
||||
raise unittest.SkipTest("Test requires >=3 devices")
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
out_host0 = jax.device_put(
|
||||
jnp.arange(8),
|
||||
@ -716,6 +727,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
def test_device_put_different_device_and_memory_hbm_to_host(self):
|
||||
if jax.device_count() < 3:
|
||||
raise unittest.SkipTest("Test requires >=3 devices")
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
out_hbm0 = jnp.arange(8)
|
||||
|
||||
@ -735,6 +749,9 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
raise unittest.SkipTest("Test requires xla_extension_version >= 199")
|
||||
if len(jax.devices()) < 2:
|
||||
raise unittest.SkipTest("Test requires >=2 devices.")
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
||||
@ -753,6 +770,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host")
|
||||
|
||||
def test_device_put_resharding(self):
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
|
||||
s_hbm = s_host.with_memory_kind("tpu_hbm")
|
||||
@ -777,6 +798,10 @@ class MemoriesTest(jtu.BufferDonationTestCase):
|
||||
out_sharded_hbm, np_inp, s_hbm, "tpu_hbm")
|
||||
|
||||
def test_jit_host_inputs_via_device_put_outside(self):
|
||||
# TODO(jieying): remove after 12/26/2023.
|
||||
if not jtu.pjrt_c_api_version_at_least(0, 32):
|
||||
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
|
||||
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
|
||||
s_hbm = s_host.with_memory_kind("tpu_hbm")
|
||||
|
Loading…
x
Reference in New Issue
Block a user