From afa518aa0ef296d8b32c71bf2b2022514f520e70 Mon Sep 17 00:00:00 2001
From: Stella-S-Yan
Date: Thu, 7 Nov 2024 00:24:32 +0000
Subject: [PATCH] Allow setting default_device with platform names.

---
 jax/_src/config.py            | 11 ++++++-----
 jax/_src/interpreters/pxla.py |  7 ++++++-
 tests/api_test.py             | 11 ++++++-----
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/jax/_src/config.py b/jax/_src/config.py
index f3edde699..30a9ba0be 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1561,7 +1561,9 @@ def _update_default_device_thread_local(val):
 
 
 def _validate_default_device(val):
-  if val is not None and not isinstance(val, xla_client.Device):
+  if (val is not None and
+      not isinstance(val, xla_client.Device) and
+      val not in ['cpu', 'gpu', 'tpu']):
     # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
     # all JAX backends use a single C++ device interface.
     if 'Device' in str(type(val)):
@@ -1569,12 +1571,11 @@ def _validate_default_device(val):
           'Allowing non-`xla_client.Device` default device: %s, type: %s',
           repr(val), type(val))
       return
-    raise ValueError('jax.default_device must be passed a Device object (e.g. '
-                     f"`jax.devices('cpu')[0]`), got: {val!r}")
+    raise ValueError('jax.default_device must be passed either a Device object (e.g. '
+                     f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
+                     f", got: {val!r}")
 
 
-# TODO(skye): default_device only accepts devices for now. Make it work with
-# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
 default_device = string_or_object_state(
     name='jax_default_device',
     default=None,
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index c83d3e3a4..2ee2ec75e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1710,7 +1710,10 @@ ShardingInfo = tuple[
 
 
 def _get_default_device() -> xc.Device:
-  return config.default_device.value or xb.local_devices()[0]
+  if isinstance(config.default_device.value, str):
+    return xb.get_backend(config.default_device.value).local_devices()[0]
+  else:
+    return config.default_device.value or xb.local_devices()[0]
 
 
 def _get_and_check_device_assignment(
@@ -1742,6 +1745,7 @@ def _get_and_check_device_assignment(
         raise DeviceAssignmentMismatchError([
             DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
             DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
+
   if first_sharding_info is None and devices:
     final_device_assignment = devices
   elif first_sharding_info is None:
@@ -2190,6 +2194,7 @@ def lower_sharding_computation(
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
       len(out_shardings), len(out_layouts), len(global_out_avals))
 
+
   devices_from_context = (None if context_mesh is None or context_mesh.empty
                           else context_mesh._flat_devices_tuple)
   # Device assignment across all inputs, outputs and shardings inside jaxpr
diff --git a/tests/api_test.py b/tests/api_test.py
index 8ab5d90f6..e61236a2d 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -287,13 +287,14 @@ class JitTest(jtu.BufferDonationTestCase):
     self.assertEqual(f(sticky).devices(), system_default_devices)
     self.assertEqual(f(1).devices(), system_default_devices)
 
-  # TODO(skye): make this work!
   def test_jit_default_platform(self):
-    with self.assertRaisesWithLiteralMatch(
-        ValueError, "jax.default_device must be passed a Device object "
-        "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
       with jax.default_device("cpu"):
-        jax.jit(lambda x: x + 1)(1)
+        result = jax.jit(lambda x: x + 1)(1)
+        self.assertEqual(result.device.platform, "cpu")
+
+      result = jax.jit(lambda x: x + 1)(1)
+      self.assertEqual(result.device.platform, jax.default_backend())
+
 
   def test_complex_support(self):
     self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)