mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Allow setting default_device with platform names.
This commit is contained in:
parent
56150286d5
commit
afa518aa0e
@ -1561,7 +1561,9 @@ def _update_default_device_thread_local(val):
|
||||
|
||||
|
||||
def _validate_default_device(val):
|
||||
if val is not None and not isinstance(val, xla_client.Device):
|
||||
if (val is not None and
|
||||
not isinstance(val, xla_client.Device) and
|
||||
val not in ['cpu', 'gpu', 'tpu']):
|
||||
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
|
||||
# all JAX backends use a single C++ device interface.
|
||||
if 'Device' in str(type(val)):
|
||||
@ -1569,12 +1571,11 @@ def _validate_default_device(val):
|
||||
'Allowing non-`xla_client.Device` default device: %s, type: %s',
|
||||
repr(val), type(val))
|
||||
return
|
||||
raise ValueError('jax.default_device must be passed a Device object (e.g. '
|
||||
f"`jax.devices('cpu')[0]`), got: {val!r}")
|
||||
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
|
||||
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
|
||||
f", got: {val!r}")
|
||||
|
||||
|
||||
# TODO(skye): default_device only accepts devices for now. Make it work with
|
||||
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
|
||||
default_device = string_or_object_state(
|
||||
name='jax_default_device',
|
||||
default=None,
|
||||
|
@ -1710,7 +1710,10 @@ ShardingInfo = tuple[
|
||||
|
||||
|
||||
def _get_default_device() -> xc.Device:
|
||||
return config.default_device.value or xb.local_devices()[0]
|
||||
if isinstance(config.default_device.value, str):
|
||||
return xb.get_backend(config.default_device.value).local_devices()[0]
|
||||
else:
|
||||
return config.default_device.value or xb.local_devices()[0]
|
||||
|
||||
|
||||
def _get_and_check_device_assignment(
|
||||
@ -1742,6 +1745,7 @@ def _get_and_check_device_assignment(
|
||||
raise DeviceAssignmentMismatchError([
|
||||
DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
|
||||
DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
|
||||
|
||||
if first_sharding_info is None and devices:
|
||||
final_device_assignment = devices
|
||||
elif first_sharding_info is None:
|
||||
@ -2190,6 +2194,7 @@ def lower_sharding_computation(
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
|
@ -287,13 +287,14 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(f(sticky).devices(), system_default_devices)
|
||||
self.assertEqual(f(1).devices(), system_default_devices)
|
||||
|
||||
# TODO(skye): make this work!
|
||||
def test_jit_default_platform(self):
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "jax.default_device must be passed a Device object "
|
||||
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
|
||||
with jax.default_device("cpu"):
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
result = jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEqual(result.device.platform, "cpu")
|
||||
|
||||
result = jax.jit(lambda x: x + 1)(1)
|
||||
self.assertEqual(result.device.platform, jax.default_backend())
|
||||
|
||||
|
||||
def test_complex_support(self):
|
||||
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
|
||||
|
Loading…
x
Reference in New Issue
Block a user