mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backend="cpu").
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes. This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future. This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting. PiperOrigin-RevId: 582786346
This commit is contained in:
parent
840b5c5d6d
commit
0560cc478e
@ -697,8 +697,11 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
def _maybe_put(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return dispatch._put_x(
|
||||
x, jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0]),
|
||||
shaped_abstractify(x), False)
|
||||
x,
|
||||
jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]),
|
||||
shaped_abstractify(x),
|
||||
False,
|
||||
)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user