mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers. The API distinguishes between two types of transfers: * explicit transfers: `jax.device_put*()` and `jax.device_get()` calls. * implicit transfers: Other transfers (e.g., printing a `DeviceArray`). The transfer guard can take an action based on its guard level: * "allow": Silently allow all transfers (default; same as the previous behavior). * "log": Log and allow implicit transfers. Silently allow explicit transfers. * "disallow": Disallow implicit transfers. Silently allow explicit transfers. * "log_explicit": Log and allow all transfers. * "disallow_explicit": Disallow all transfers. The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction: * "host_to_device": Converting a Python value into a `DeviceBuffer`. * "device_to_device": Copying a `DeviceBuffer` to a different device. * "device_to_host": Fetching the value of a `DeviceBuffer`. Example: ``` x = jnp.array(1) y = jnp.array(2) z = jnp.array(3) print(x) # No error with jax.transfer_guard("disallow"): print(x) # No error; x is already fetched print(jax.device_get(y)) # No error print(z) # Error! ``` PiperOrigin-RevId: 428590081
This commit is contained in:
parent
f229a703e7
commit
beaa00c460
@ -96,6 +96,10 @@ _XLA_EXTENSION_STUBS = [
|
||||
"pmap_lib.pyi",
|
||||
"profiler.pyi",
|
||||
"pytree.pyi",
|
||||
"transfer_guard_lib.pyi",
|
||||
]
|
||||
_OPTIONAL_XLA_EXTENSION_STUBS = [
|
||||
"transfer_guard_lib.pyi", # Will be required on xla_extension_version >= 58.
|
||||
]
|
||||
|
||||
|
||||
@ -107,8 +111,12 @@ def patch_copy_xla_extension_stubs(dst_dir):
|
||||
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
|
||||
os.makedirs(xla_extension_dir)
|
||||
for stub_name in _XLA_EXTENSION_STUBS:
|
||||
with open(r.Rlocation(
|
||||
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)) as f:
|
||||
stub_path = r.Rlocation(
|
||||
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)
|
||||
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
|
||||
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
|
||||
continue
|
||||
with open(stub_path) as f:
|
||||
src = f.read()
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python import xla_extension",
|
||||
|
@ -49,7 +49,11 @@ from jax._src.config import (
|
||||
default_matmul_precision as default_matmul_precision,
|
||||
default_prng_impl as default_prng_impl,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions
|
||||
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
|
||||
transfer_guard as transfer_guard,
|
||||
transfer_guard_host_to_device as transfer_guard_host_to_device,
|
||||
transfer_guard_device_to_device as transfer_guard_device_to_device,
|
||||
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
||||
)
|
||||
from .core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.api import (
|
||||
|
@ -84,11 +84,14 @@ from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.ad_checkpoint import checkpoint_policies
|
||||
|
||||
from jax._src.config import (flags, config, bool_env,
|
||||
disable_jit as _disable_jit,
|
||||
debug_nans as config_debug_nans,
|
||||
debug_infs as config_debug_infs,
|
||||
_thread_local_state as config_thread_local_state)
|
||||
from jax._src.config import (
|
||||
flags, config, bool_env,
|
||||
disable_jit as _disable_jit,
|
||||
debug_nans as config_debug_nans,
|
||||
debug_infs as config_debug_infs,
|
||||
_thread_local_state as config_thread_local_state,
|
||||
explicit_device_put_scope as config_explicit_device_put_scope,
|
||||
explicit_device_get_scope as config_explicit_device_get_scope)
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -2750,7 +2753,8 @@ def device_put(x, device: Optional[xc.Device] = None):
|
||||
Returns:
|
||||
A copy of ``x`` that resides on ``device``.
|
||||
"""
|
||||
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
|
||||
|
||||
|
||||
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
|
||||
@ -2819,7 +2823,8 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
|
||||
for buf in dispatch.device_put(x, d)]
|
||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||
|
||||
return tree_multimap(_device_put_sharded, *shards)
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_multimap(_device_put_sharded, *shards)
|
||||
|
||||
|
||||
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
|
||||
@ -2862,7 +2867,9 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
|
||||
buf, = dispatch.device_put(x, devices[0])
|
||||
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
|
||||
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
|
||||
return tree_map(_device_put_replicated, x)
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_map(_device_put_replicated, x)
|
||||
|
||||
|
||||
# TODO(mattjj): consider revising
|
||||
@ -2907,12 +2914,13 @@ def device_get(x: Any):
|
||||
- device_put_sharded
|
||||
- device_put_replicated
|
||||
"""
|
||||
for y in tree_leaves(x):
|
||||
try:
|
||||
y.copy_to_host_async()
|
||||
except AttributeError:
|
||||
pass
|
||||
return tree_map(_device_get, x)
|
||||
with config_explicit_device_get_scope():
|
||||
for y in tree_leaves(x):
|
||||
try:
|
||||
y.copy_to_host_async()
|
||||
except AttributeError:
|
||||
pass
|
||||
return tree_map(_device_get, x)
|
||||
|
||||
def _check_arg(arg):
|
||||
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
|
||||
|
@ -21,11 +21,13 @@ import itertools
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, List, Callable, NamedTuple, Optional
|
||||
from typing import Any, List, Callable, NamedTuple, Iterator, Optional
|
||||
import warnings
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
if lib.xla_extension_version >= 58:
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
@ -685,3 +687,140 @@ config.define_bool_state(
|
||||
default=False,
|
||||
help=('Enables experimental features for staging out computations with '
|
||||
'dynamic shapes.'))
|
||||
|
||||
if lib.xla_extension_version < 58:
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _transfer_guard(new_val: str) -> Iterator[None]:
|
||||
raise NotImplementedError("jaxlib version is too low for transfer guards")
|
||||
|
||||
transfer_guard_host_to_device = _transfer_guard
|
||||
transfer_guard_device_to_device = _transfer_guard
|
||||
transfer_guard_device_to_host = _transfer_guard
|
||||
transfer_guard = _transfer_guard
|
||||
|
||||
else:
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_put
|
||||
state.explicit_device_put = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_put = prev
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_get
|
||||
state.explicit_device_get = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_get = prev
|
||||
|
||||
def _update_transfer_guard(state, key, val):
|
||||
"""Applies the transfer guard level within transfer_guard_lib."""
|
||||
if val is None:
|
||||
setattr(state, key, None)
|
||||
elif val == 'allow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
|
||||
elif val == 'log':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
|
||||
elif val == 'disallow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
|
||||
elif val == 'log_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
||||
elif val == 'disallow_explicit':
|
||||
setattr(state, key,
|
||||
transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
||||
else:
|
||||
assert False, f"Invalid transfer guard level {val}"
|
||||
|
||||
transfer_guard_host_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_host_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for host-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'host_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
|
||||
|
||||
transfer_guard_device_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
|
||||
|
||||
transfer_guard_device_to_host = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_host',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-host transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_host', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
|
||||
|
||||
def _update_all_transfer_guard_global(val):
|
||||
for name in ('jax_transfer_guard_host_to_device',
|
||||
'jax_transfer_guard_device_to_device',
|
||||
'jax_transfer_guard_device_to_host'):
|
||||
config.update(name, val)
|
||||
|
||||
_transfer_guard = config.define_enum_state(
|
||||
name='jax_transfer_guard',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard_*.
|
||||
default=None,
|
||||
help=(
|
||||
'Select the transfer guard level for all transfers. This option is '
|
||||
'set-only; the transfer guard level for a specific direction should '
|
||||
'be read using the per-transfer direction option. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=_update_all_transfer_guard_global)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it."""
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(transfer_guard_host_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_host(new_val))
|
||||
stack.enter_context(_transfer_guard(new_val))
|
||||
yield
|
||||
|
@ -152,3 +152,6 @@ cuda_path: Optional[str]
|
||||
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
|
||||
if not os.path.isdir(cuda_path):
|
||||
cuda_path = None
|
||||
|
||||
if xla_extension_version >= 58:
|
||||
transfer_guard_lib = xla_client._xla.transfer_guard_lib
|
||||
|
249
tests/transfer_guard_test.py
Normal file
249
tests/transfer_guard_test.py
Normal file
@ -0,0 +1,249 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for transfer guards."""
|
||||
|
||||
import contextlib
|
||||
import pickle
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax._src.test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _host_to_device_funcs():
|
||||
"""Generates host-to-device transfer functions."""
|
||||
return [
|
||||
# (function name, is an explicit transfer?, function)
|
||||
("host_to_device_jax_device_put", True,
|
||||
lambda: jax.device_put(np.ones(10))),
|
||||
("host_to_device_jax_jit", False, lambda: jax.jit(lambda x: x)
|
||||
(np.ones(1))),
|
||||
("host_to_device_jnp_one", False, lambda: jnp.ones(1)),
|
||||
]
|
||||
|
||||
|
||||
def _device_to_device_funcs():
|
||||
"""Generates device-to-device transfer functions."""
|
||||
if len(jax.local_devices()) < 2:
|
||||
# device-to-device tests require at least 2 devices.
|
||||
return []
|
||||
|
||||
with jax.transfer_guard_host_to_device("allow"):
|
||||
device_arrays = [jnp.ones(1) for _ in range(2)]
|
||||
return [
|
||||
# (function name, is an explicit transfer?, function)
|
||||
("device_to_device_jax_device_put", True,
|
||||
lambda: jax.device_put(device_arrays[0], device=jax.local_devices()[1])),
|
||||
("device_to_device_jax_jit", False,
|
||||
lambda: jax.jit(lambda x: x, device=jax.local_devices()[1])
|
||||
(device_arrays[1])),
|
||||
]
|
||||
|
||||
|
||||
def _device_to_host_funcs():
|
||||
"""Generates device-to-host transfer functions."""
|
||||
if jax.default_backend() == "cpu":
|
||||
# device-to-host does not incur transfer on the CPU backend.
|
||||
return []
|
||||
|
||||
with jax.transfer_guard_host_to_device("allow"):
|
||||
device_arrays = [jnp.ones(1) for _ in range(6)]
|
||||
return [
|
||||
# (function name, is an explicit transfer?, function)
|
||||
("device_to_host_jax_device_get", True,
|
||||
lambda: jax.device_get(device_arrays[0])),
|
||||
("device_to_host_np_asarray", False,
|
||||
lambda: np.asarray(device_arrays[1])),
|
||||
("device_to_host_copy_to_host_async", False,
|
||||
lambda: device_arrays[2].copy_to_host_async()),
|
||||
("device_to_host_np_add", False, lambda: np.add(device_arrays[3], 1)),
|
||||
("device_to_host_str", False, lambda: str(device_arrays[4])),
|
||||
("device_to_host_pickle_dumps", False,
|
||||
lambda: pickle.dumps(device_arrays[5])),
|
||||
]
|
||||
|
||||
|
||||
def _all_funcs():
|
||||
"""Generates all transfer functions."""
|
||||
return (_host_to_device_funcs() + _device_to_device_funcs() +
|
||||
_device_to_host_funcs())
|
||||
|
||||
|
||||
# List of test parameters shared by multiple tests.
|
||||
_COMMON_TEST_PARAMETERS = [
|
||||
("host_to_device", _host_to_device_funcs,
|
||||
jax.transfer_guard_host_to_device),
|
||||
("device_to_device", _device_to_device_funcs,
|
||||
jax.transfer_guard_device_to_device),
|
||||
("device_to_host", _device_to_host_funcs,
|
||||
jax.transfer_guard_device_to_host),
|
||||
("all", _all_funcs, jax.transfer_guard),
|
||||
]
|
||||
|
||||
if jax._src.lib.xla_extension_version < 58:
|
||||
|
||||
class TransferGuardTest(jtu.JaxTestCase):
|
||||
pass
|
||||
|
||||
else:
|
||||
|
||||
class TransferGuardTest(jtu.JaxTestCase):
|
||||
# `_default_config` is used by `jtu.JaxTestCase` to update the JAX config
|
||||
# for every test case. TransferGuardTest disables `--jax_enable_checks`
|
||||
# because it can prematurely fetch the value of device arrays and make
|
||||
# device-to-host tests to incur no transfers unexpectedly.
|
||||
_default_config = {"jax_enable_checks": False}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertAllows(self, func_name):
|
||||
"""Asserts that a transfer in the context is allowed."""
|
||||
try:
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be allowed while running: {func_name}"
|
||||
) from e
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertLogs(self, func_name):
|
||||
"""Asserts that a transfer in the context is logged and allowed."""
|
||||
# Only check if the transfer is allowed until Abseil provides an
|
||||
# interface to capture logs.
|
||||
with self.assertAllows(func_name):
|
||||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertDisallows(self, func_name):
|
||||
"""Asserts that a transfer in the context is disallowed."""
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be disallowed while running: {func_name}"
|
||||
) from e
|
||||
|
||||
def test_simple(self):
|
||||
"""Simple transfer guard tests."""
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("log"):
|
||||
with self.assertLogs("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("disallow"):
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
def test_nesting(self):
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
def test_mixed_nesting(self):
|
||||
with jax.transfer_guard_host_to_device("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard_host_to_device("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow_by_default(self, func_generator, _):
|
||||
for func_name, _, func in func_generator():
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("log"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
else:
|
||||
with self.assertLogs(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("disallow"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
else:
|
||||
with self.assertDisallows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("device_to_host", _device_to_host_funcs,
|
||||
jax.transfer_guard_device_to_host),
|
||||
("all", _device_to_host_funcs, jax.transfer_guard),
|
||||
)
|
||||
def test_disallow_ignores_arrays_on_cpu(self, func_generator,
|
||||
jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
# Transfer the device array to host.
|
||||
func()
|
||||
with jax_transfer_guard("disallow"):
|
||||
with self.assertAllows(func_name):
|
||||
# No error because the array has a value on host and no new transfer
|
||||
# will occur.
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("log_explicit"):
|
||||
with self.assertLogs(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("disallow_explicit"):
|
||||
with self.assertDisallows(func_name):
|
||||
func()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user