Hyeontaek Lim c88ea23035 [JAX] Add caching to colocated_python.colocated_cpu_devices()
For a deployment with many devices, `colocated_python.colocated_cpu_devices()`
can take a noticeable amount of time, because it finds the matching colocated
devices one by one in Python.

This change adds caching as an optimization that reduces the overall cost of
repeated API calls.
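A minimal sketch of the general idea, not the JAX implementation: a slow
per-device linear scan is wrapped in `functools.lru_cache`, so a repeated call
with the same devices skips the scan. `FakeDevice`, `ALL_DEVICES`,
`_find_colocated_cpu_device`, and `colocated_cpu_devices_cached` are
hypothetical names, and matching a CPU device by `process_index` is an
assumption made only for illustration.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    # Hypothetical stand-in for a device; frozen=True makes it hashable,
    # which lru_cache needs for its keys.
    id: int
    process_index: int
    platform: str


# Hypothetical inventory: 8 accelerators across 2 processes, 1 CPU per process.
ALL_DEVICES = tuple(
    FakeDevice(id=i, process_index=i // 4, platform="tpu") for i in range(8)
) + tuple(
    FakeDevice(id=100 + p, process_index=p, platform="cpu") for p in range(2)
)

LOOKUPS = 0  # counts how many linear scans actually ran


def _find_colocated_cpu_device(device):
    # Uncached path: scan every known device, one input device at a time.
    global LOOKUPS
    LOOKUPS += 1
    for d in ALL_DEVICES:
        if d.platform == "cpu" and d.process_index == device.process_index:
            return d
    raise ValueError(f"no CPU device in process {device.process_index}")


@functools.lru_cache(maxsize=None)
def colocated_cpu_devices_cached(devices):
    # `devices` must be a hashable tuple; the result is cached per tuple.
    return tuple(_find_colocated_cpu_device(d) for d in devices)


tpus = tuple(d for d in ALL_DEVICES if d.platform == "tpu")
first = colocated_cpu_devices_cached(tpus)   # runs 8 scans
second = colocated_cpu_devices_cached(tpus)  # served from the cache
```

The sketch assumes the set of devices does not change for the life of the
process; if it could change, the cache would need to be invalidated.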

PiperOrigin-RevId: 740930124
2025-03-26 15:48:48 -07:00