mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use JAX's default device instead of jax.devices()[0], if set.
PiperOrigin-RevId: 702515221
This commit is contained in:
parent
fcf0b6d3da
commit
fd4b160880
@ -440,7 +440,7 @@ def _device_put_sharding_impl(x, aval, device, copy):
|
||||
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
|
||||
[device])
|
||||
|
||||
sh = SingleDeviceSharding(pxla._get_default_device()
|
||||
sh = SingleDeviceSharding(pxla.get_default_device()
|
||||
if device is None else device)
|
||||
return _DeferredShardArg(x, sh, aval, device is not None, copy)
|
||||
|
||||
|
@ -1710,7 +1710,7 @@ ShardingInfo = tuple[
|
||||
]
|
||||
|
||||
|
||||
def _get_default_device() -> xc.Device:
|
||||
def get_default_device() -> xc.Device:
|
||||
if isinstance(config.default_device.value, str):
|
||||
return xb.get_backend(config.default_device.value).local_devices()[0]
|
||||
else:
|
||||
@ -1749,7 +1749,7 @@ def _get_and_check_device_assignment(
|
||||
if first_sharding_info is None and devices:
|
||||
final_device_assignment = devices
|
||||
elif first_sharding_info is None:
|
||||
final_device_assignment = (_get_default_device(),)
|
||||
final_device_assignment = (get_default_device(),)
|
||||
else:
|
||||
final_device_assignment = first_sharding_info[0] # type: ignore
|
||||
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
|
||||
|
@ -124,6 +124,7 @@ py_library(
|
||||
"//jax:pallas",
|
||||
"//jax:util",
|
||||
"//jax/_src/pallas",
|
||||
"//jax/extend:backend",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
|
@ -33,6 +33,7 @@ from jax._src.pallas import primitives as primitives
|
||||
from jax._src.pallas.mosaic import core as tpu_core
|
||||
from jax._src.pallas.mosaic import primitives as tpu_primitives
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.extend.backend import get_default_device
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -75,7 +76,7 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
|
||||
|
||||
@jax_util.cache(trace_context_in_key=False)
|
||||
def _get_tpu_generation() -> int:
|
||||
kind = jax.devices()[0].device_kind
|
||||
kind = get_default_device().device_kind
|
||||
if kind.endswith(' lite'):
|
||||
kind = kind[:-len(' lite')]
|
||||
assert kind[:5] == "TPU v", kind
|
||||
|
@ -24,3 +24,6 @@ from jax._src.xla_bridge import (
|
||||
get_backend as get_backend,
|
||||
register_backend_factory as register_backend_factory,
|
||||
)
|
||||
from jax._src.interpreters.pxla import (
|
||||
get_default_device as get_default_device
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user