For a deployment with many devices, `colocated_python.colocated_cpu_devices()` can take a long time to find colocated devices because it searches for matching devices one by one in Python. This change adds caching to reduce the overall cost of repeated calls to this API.

PiperOrigin-RevId: 740930124
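The idea can be illustrated with a small, self-contained sketch. This is not the JAX implementation. `Device`, `ALL_DEVICES`, `_find_colocated_cpu`, and this standalone `colocated_cpu_devices` are hypothetical stand-ins. The sketch assumes a one-by-one scan is memoized with `functools.lru_cache`, keyed on a hashable tuple of devices. The first call pays for the linear scans. A repeated call with the same devices is answered from the cache.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for a device handle (frozen, so it is hashable)."""
    id: int
    platform: str  # e.g. "tpu" or "cpu"
    host: int      # index of the host that owns the device


# Hypothetical device list: 8 accelerators spread over 2 hosts, 1 CPU per host.
ALL_DEVICES = tuple(
    [Device(id=i, platform="tpu", host=i // 4) for i in range(8)]
    + [Device(id=100 + h, platform="cpu", host=h) for h in range(2)]
)

scan_count = 0  # number of linear scans performed, to make the cache visible


def _find_colocated_cpu(device):
    """Scan every device, one by one, for a CPU device on the same host."""
    global scan_count
    scan_count += 1
    for candidate in ALL_DEVICES:
        if candidate.platform == "cpu" and candidate.host == device.host:
            return candidate
    raise ValueError(f"no colocated CPU device for {device}")


@functools.lru_cache(maxsize=None)
def colocated_cpu_devices(devices):
    """Return the CPU device colocated with each input device (memoized).

    `devices` must be a tuple so that it can be used as the cache key.
    """
    return tuple(_find_colocated_cpu(d) for d in devices)


accelerators = tuple(d for d in ALL_DEVICES if d.platform == "tpu")
first = colocated_cpu_devices(accelerators)   # performs 8 scans
second = colocated_cpu_devices(accelerators)  # served from the cache
print(scan_count, first is second)            # prints "8 True"
```

In this sketch, caching trades a small amount of memory for skipping the scan on repeated calls. That trade is only safe under the assumption that the set of devices does not change for the lifetime of the process.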