Use JAX's default device instead of jax.devices()[0], if set.

PiperOrigin-RevId: 702515221
This commit is contained in:
jax authors 2024-12-03 16:53:20 -08:00
parent fcf0b6d3da
commit fd4b160880
5 changed files with 9 additions and 4 deletions

View File

@ -440,7 +440,7 @@ def _device_put_sharding_impl(x, aval, device, copy):
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
[device])
sh = SingleDeviceSharding(pxla._get_default_device()
sh = SingleDeviceSharding(pxla.get_default_device()
if device is None else device)
return _DeferredShardArg(x, sh, aval, device is not None, copy)

View File

@ -1710,7 +1710,7 @@ ShardingInfo = tuple[
]
def _get_default_device() -> xc.Device:
def get_default_device() -> xc.Device:
if isinstance(config.default_device.value, str):
return xb.get_backend(config.default_device.value).local_devices()[0]
else:
@ -1749,7 +1749,7 @@ def _get_and_check_device_assignment(
if first_sharding_info is None and devices:
final_device_assignment = devices
elif first_sharding_info is None:
final_device_assignment = (_get_default_device(),)
final_device_assignment = (get_default_device(),)
else:
final_device_assignment = first_sharding_info[0] # type: ignore
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

View File

@ -124,6 +124,7 @@ py_library(
"//jax:pallas",
"//jax:util",
"//jax/_src/pallas",
"//jax/extend:backend",
] + py_deps("numpy"),
)

View File

@ -33,6 +33,7 @@ from jax._src.pallas import primitives as primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax.experimental import pallas as pl
from jax.extend.backend import get_default_device
import jax.numpy as jnp
import numpy as np
@ -75,7 +76,7 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
@jax_util.cache(trace_context_in_key=False)
def _get_tpu_generation() -> int:
kind = jax.devices()[0].device_kind
kind = get_default_device().device_kind
if kind.endswith(' lite'):
kind = kind[:-len(' lite')]
assert kind[:5] == "TPU v", kind

View File

@ -24,3 +24,6 @@ from jax._src.xla_bridge import (
get_backend as get_backend,
register_backend_factory as register_backend_factory,
)
from jax._src.interpreters.pxla import (
get_default_device as get_default_device
)