mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Cache _get_tpu_generation
to avoid repeated calls to jax.devices().
We use `util.cache` such that if the default backend changes this function will (correctly) be re-evaluated. PiperOrigin-RevId: 661293560
This commit is contained in:
parent
40e67c73ee
commit
5ced6db692
@ -72,6 +72,7 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
|
||||
return tree_util.tree_unflatten(treedef, broadcast_leaves)
|
||||
|
||||
|
||||
@jax_util.cache(trace_context_in_key=False)
|
||||
def _get_tpu_generation() -> int:
|
||||
kind = jax.devices()[0].device_kind
|
||||
if kind.endswith(' lite'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user