Cache _get_tpu_generation to avoid repeated calls to jax.devices().

We use `util.cache` such that if the default backend changes this function
will (correctly) be re-evaluated.

PiperOrigin-RevId: 661293560
This commit is contained in:
Tom Hennigan 2024-08-09 09:33:11 -07:00 committed by jax authors
parent 40e67c73ee
commit 5ced6db692

View File

@ -72,6 +72,7 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
return tree_util.tree_unflatten(treedef, broadcast_leaves)
@jax_util.cache(trace_context_in_key=False)
def _get_tpu_generation() -> int:
kind = jax.devices()[0].device_kind
if kind.endswith(' lite'):