Allow setting default_device with platform names.

This commit is contained in:
Stella-S-Yan 2024-11-07 00:24:32 +00:00
parent 56150286d5
commit afa518aa0e
3 changed files with 18 additions and 11 deletions

View File

@ -1561,7 +1561,9 @@ def _update_default_device_thread_local(val):
def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
if (val is not None and
not isinstance(val, xla_client.Device) and
val not in ['cpu', 'gpu', 'tpu']):
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
@ -1569,12 +1571,11 @@ def _validate_default_device(val):
'Allowing non-`xla_client.Device` default device: %s, type: %s',
repr(val), type(val))
return
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {val!r}")
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
f", got: {val!r}")
# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = string_or_object_state(
name='jax_default_device',
default=None,

View File

@ -1710,7 +1710,10 @@ ShardingInfo = tuple[
def _get_default_device() -> xc.Device:
return config.default_device.value or xb.local_devices()[0]
if isinstance(config.default_device.value, str):
return xb.get_backend(config.default_device.value).local_devices()[0]
else:
return config.default_device.value or xb.local_devices()[0]
def _get_and_check_device_assignment(
@ -1742,6 +1745,7 @@ def _get_and_check_device_assignment(
raise DeviceAssignmentMismatchError([
DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
if first_sharding_info is None and devices:
final_device_assignment = devices
elif first_sharding_info is None:
@ -2190,6 +2194,7 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr

View File

@ -287,13 +287,14 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(f(sticky).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)
# TODO(skye): make this work!
def test_jit_default_platform(self):
with self.assertRaisesWithLiteralMatch(
ValueError, "jax.default_device must be passed a Device object "
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
with jax.default_device("cpu"):
jax.jit(lambda x: x + 1)(1)
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, "cpu")
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())
def test_complex_support(self):
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)