mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use _internal_device_list
in _get_device
so that all places accessing _get_device
get a speedup.
PiperOrigin-RevId: 624320655
This commit is contained in:
parent
198ce52bf7
commit
001732086b
@ -56,9 +56,9 @@ Index = tuple[slice, ...]
|
||||
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
def _get_device(a: ArrayImpl) -> Device:
|
||||
devices = a.devices()
|
||||
devices = a.sharding._internal_device_list # type: ignore
|
||||
assert len(devices) == 1
|
||||
return next(iter(devices))
|
||||
return devices[0]
|
||||
|
||||
|
||||
class Shard:
|
||||
|
Loading…
x
Reference in New Issue
Block a user