mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10229 from hyeontaek:transfer-guard-remove-compat-code
PiperOrigin-RevId: 441490830
This commit is contained in:
commit
86c8446c00
@ -99,7 +99,6 @@ _XLA_EXTENSION_STUBS = [
|
||||
"transfer_guard_lib.pyi",
|
||||
]
|
||||
_OPTIONAL_XLA_EXTENSION_STUBS = [
|
||||
"transfer_guard_lib.pyi", # Will be required on xla_extension_version >= 58.
|
||||
]
|
||||
|
||||
|
||||
|
@ -26,8 +26,7 @@ import warnings
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
if lib.xla_extension_version >= 58:
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
@ -720,139 +719,116 @@ config.define_bool_state(
|
||||
default=True,
|
||||
help=('Enables using optimization-barrier op for lowering remat.'))
|
||||
|
||||
if lib.xla_extension_version < 58:
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_put
|
||||
state.explicit_device_put = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_put = prev
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_get
|
||||
state.explicit_device_get = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_get = prev
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _transfer_guard(new_val: str) -> Iterator[None]:
|
||||
raise NotImplementedError("jaxlib version is too low for transfer guards")
|
||||
def _update_transfer_guard(state, key, val):
|
||||
"""Applies the transfer guard level within transfer_guard_lib."""
|
||||
if val is None:
|
||||
setattr(state, key, None)
|
||||
elif val == 'allow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
|
||||
elif val == 'log':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
|
||||
elif val == 'disallow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
|
||||
elif val == 'log_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
||||
elif val == 'disallow_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
||||
else:
|
||||
assert False, f'Invalid transfer guard level {val}'
|
||||
|
||||
transfer_guard_host_to_device = _transfer_guard
|
||||
transfer_guard_device_to_device = _transfer_guard
|
||||
transfer_guard_device_to_host = _transfer_guard
|
||||
transfer_guard = _transfer_guard
|
||||
transfer_guard_host_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_host_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for host-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'host_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
|
||||
|
||||
else:
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_put_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_put*() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_put
|
||||
state.explicit_device_put = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_put = prev
|
||||
transfer_guard_device_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def explicit_device_get_scope() -> Iterator[None]:
|
||||
"""Indicates that the current context is an explicit device_get() call."""
|
||||
state = transfer_guard_lib.thread_local_state()
|
||||
prev = state.explicit_device_get
|
||||
state.explicit_device_get = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
state.explicit_device_get = prev
|
||||
transfer_guard_device_to_host = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_host',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-host transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_host', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
|
||||
|
||||
def _update_transfer_guard(state, key, val):
|
||||
"""Applies the transfer guard level within transfer_guard_lib."""
|
||||
if val is None:
|
||||
setattr(state, key, None)
|
||||
elif val == 'allow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
|
||||
elif val == 'log':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
|
||||
elif val == 'disallow':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
|
||||
elif val == 'log_explicit':
|
||||
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
||||
elif val == 'disallow_explicit':
|
||||
setattr(state, key,
|
||||
transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
||||
else:
|
||||
assert False, f"Invalid transfer guard level {val}"
|
||||
def _update_all_transfer_guard_global(val):
|
||||
for name in ('jax_transfer_guard_host_to_device',
|
||||
'jax_transfer_guard_device_to_device',
|
||||
'jax_transfer_guard_device_to_host'):
|
||||
config.update(name, val)
|
||||
|
||||
transfer_guard_host_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_host_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for host-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'host_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
|
||||
|
||||
transfer_guard_device_to_device = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_device',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-device transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_device', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
|
||||
|
||||
transfer_guard_device_to_host = config.define_enum_state(
|
||||
name='jax_transfer_guard_device_to_host',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for device-to-host transfers. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.global_state(), 'device_to_host', val),
|
||||
update_thread_local_hook=lambda val: _update_transfer_guard(
|
||||
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
|
||||
|
||||
def _update_all_transfer_guard_global(val):
|
||||
for name in ('jax_transfer_guard_host_to_device',
|
||||
'jax_transfer_guard_device_to_device',
|
||||
'jax_transfer_guard_device_to_host'):
|
||||
config.update(name, val)
|
||||
|
||||
_transfer_guard = config.define_enum_state(
|
||||
name='jax_transfer_guard',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard_*.
|
||||
default=None,
|
||||
help=(
|
||||
'Select the transfer guard level for all transfers. This option is '
|
||||
_transfer_guard = config.define_enum_state(
|
||||
name='jax_transfer_guard',
|
||||
enum_values=[
|
||||
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
||||
],
|
||||
# The default is applied by transfer_guard_lib. Use None here to avoid
|
||||
# accidentally overriding --jax_transfer_guard_*.
|
||||
default=None,
|
||||
help=('Select the transfer guard level for all transfers. This option is '
|
||||
'set-only; the transfer guard level for a specific direction should '
|
||||
'be read using the per-transfer direction option. '
|
||||
'Default is "allow".'),
|
||||
update_global_hook=_update_all_transfer_guard_global)
|
||||
update_global_hook=_update_all_transfer_guard_global)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it."""
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(transfer_guard_host_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_host(new_val))
|
||||
stack.enter_context(_transfer_guard(new_val))
|
||||
yield
|
||||
@contextlib.contextmanager
|
||||
def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it."""
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(transfer_guard_host_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_device(new_val))
|
||||
stack.enter_context(transfer_guard_device_to_host(new_val))
|
||||
stack.enter_context(_transfer_guard(new_val))
|
||||
yield
|
||||
|
@ -173,5 +173,4 @@ cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
|
||||
if not os.path.isdir(cuda_path):
|
||||
cuda_path = None
|
||||
|
||||
if xla_extension_version >= 58:
|
||||
transfer_guard_lib = xla_client._xla.transfer_guard_lib
|
||||
transfer_guard_lib = xla_client._xla.transfer_guard_lib
|
||||
|
@ -18,5 +18,5 @@ def _version_as_tuple(version_str):
|
||||
__version__ = "0.3.7"
|
||||
__version_info__ = _version_as_tuple(__version__)
|
||||
|
||||
_minimum_jaxlib_version = "0.3.0"
|
||||
_minimum_jaxlib_version = "0.3.2"
|
||||
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)
|
||||
|
@ -100,150 +100,144 @@ _COMMON_TEST_PARAMETERS = [
|
||||
("all", _all_funcs, jax.transfer_guard),
|
||||
]
|
||||
|
||||
if jax._src.lib.xla_extension_version < 58:
|
||||
|
||||
class TransferGuardTest(jtu.JaxTestCase):
|
||||
pass
|
||||
class TransferGuardTest(jtu.JaxTestCase):
|
||||
# `_default_config` is used by `jtu.JaxTestCase` to update the JAX config for
|
||||
# every test case. TransferGuardTest disables `--jax_enable_checks` because it
|
||||
# can prematurely fetch the value of device arrays and make device-to-host
|
||||
# tests to incur no transfers unexpectedly.
|
||||
_default_config = {"jax_enable_checks": False}
|
||||
|
||||
else:
|
||||
@contextlib.contextmanager
|
||||
def assertAllows(self, func_name):
|
||||
"""Asserts that a transfer in the context is allowed."""
|
||||
try:
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be allowed while running: {func_name}"
|
||||
) from e
|
||||
|
||||
class TransferGuardTest(jtu.JaxTestCase):
|
||||
# `_default_config` is used by `jtu.JaxTestCase` to update the JAX config
|
||||
# for every test case. TransferGuardTest disables `--jax_enable_checks`
|
||||
# because it can prematurely fetch the value of device arrays and make
|
||||
# device-to-host tests to incur no transfers unexpectedly.
|
||||
_default_config = {"jax_enable_checks": False}
|
||||
@contextlib.contextmanager
|
||||
def assertLogs(self, func_name):
|
||||
"""Asserts that a transfer in the context is logged and allowed."""
|
||||
# Only check if the transfer is allowed until Abseil provides an interface
|
||||
# to capture logs.
|
||||
with self.assertAllows(func_name):
|
||||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertAllows(self, func_name):
|
||||
"""Asserts that a transfer in the context is allowed."""
|
||||
try:
|
||||
@contextlib.contextmanager
|
||||
def assertDisallows(self, func_name):
|
||||
"""Asserts that a transfer in the context is disallowed."""
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be allowed while running: {func_name}"
|
||||
) from e
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be disallowed while running: {func_name}"
|
||||
) from e
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertLogs(self, func_name):
|
||||
"""Asserts that a transfer in the context is logged and allowed."""
|
||||
# Only check if the transfer is allowed until Abseil provides an
|
||||
# interface to capture logs.
|
||||
with self.assertAllows(func_name):
|
||||
yield
|
||||
def test_simple(self):
|
||||
"""Simple transfer guard tests."""
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("log"):
|
||||
with self.assertLogs("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("disallow"):
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertDisallows(self, func_name):
|
||||
"""Asserts that a transfer in the context is disallowed."""
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
raise RuntimeError(
|
||||
f"Expected a transfer to be disallowed while running: {func_name}"
|
||||
) from e
|
||||
|
||||
def test_simple(self):
|
||||
"""Simple transfer guard tests."""
|
||||
def test_nesting(self):
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("log"):
|
||||
with self.assertLogs("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with jax.transfer_guard("disallow"):
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
def test_nesting(self):
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
def test_mixed_nesting(self):
|
||||
with jax.transfer_guard_host_to_device("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
def test_mixed_nesting(self):
|
||||
with jax.transfer_guard_host_to_device("disallow"):
|
||||
with jax.transfer_guard("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard_host_to_device("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
|
||||
with jax.transfer_guard("disallow"):
|
||||
with jax.transfer_guard_host_to_device("allow"):
|
||||
with self.assertAllows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
with self.assertDisallows("host_to_device_jnp_ones"):
|
||||
jnp.ones(1)
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow_by_default(self, func_generator, _):
|
||||
for func_name, _, func in func_generator():
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow_by_default(self, func_generator, _):
|
||||
for func_name, _, func in func_generator():
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_allow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("log"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("log"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
else:
|
||||
with self.assertLogs(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("disallow"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
else:
|
||||
with self.assertDisallows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("device_to_host", _device_to_host_funcs,
|
||||
jax.transfer_guard_device_to_host),
|
||||
("all", _device_to_host_funcs, jax.transfer_guard),
|
||||
)
|
||||
def test_disallow_ignores_arrays_on_cpu(self, func_generator,
|
||||
jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
# Transfer the device array to host.
|
||||
func()
|
||||
with jax_transfer_guard("disallow"):
|
||||
with self.assertAllows(func_name):
|
||||
# No error because the array has a value on host and no new transfer
|
||||
# will occur.
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("log_explicit"):
|
||||
else:
|
||||
with self.assertLogs(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("disallow_explicit"):
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow(self, func_generator, jax_transfer_guard):
|
||||
for func_name, explicit, func in func_generator():
|
||||
with jax_transfer_guard("disallow"):
|
||||
if explicit:
|
||||
with self.assertAllows(func_name):
|
||||
func()
|
||||
else:
|
||||
with self.assertDisallows(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("device_to_host", _device_to_host_funcs,
|
||||
jax.transfer_guard_device_to_host),
|
||||
("all", _device_to_host_funcs, jax.transfer_guard),
|
||||
)
|
||||
def test_disallow_ignores_arrays_on_cpu(self, func_generator,
|
||||
jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("allow"):
|
||||
# Transfer the device array to host.
|
||||
func()
|
||||
with jax_transfer_guard("disallow"):
|
||||
with self.assertAllows(func_name):
|
||||
# No error because the array has a value on host and no new transfer
|
||||
# will occur.
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_log_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("log_explicit"):
|
||||
with self.assertLogs(func_name):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
|
||||
def test_disallow_explicit(self, func_generator, jax_transfer_guard):
|
||||
for func_name, _, func in func_generator():
|
||||
with jax_transfer_guard("disallow_explicit"):
|
||||
with self.assertDisallows(func_name):
|
||||
func()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user