Use _internal_device_list in _get_device so that all places accessing _get_device get a speedup.

PiperOrigin-RevId: 624320655
This commit is contained in:
Yash Katariya 2024-04-12 16:12:20 -07:00 committed by jax authors
parent 198ce52bf7
commit 001732086b

View File

@ -56,9 +56,9 @@ Index = tuple[slice, ...]
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
devices = a.devices()
devices = a.sharding._internal_device_list # type: ignore
assert len(devices) == 1
return next(iter(devices))
return devices[0]
class Shard: