mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Set JAX_PLATFORMS=tpu,cpu on TPUs
This commit is contained in:
parent
015a12e155
commit
78a7e161bb
@ -48,6 +48,7 @@ def cloud_tpu_init():
|
||||
|
||||
libtpu.configure_library_path()
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
|
||||
# If the user has set any topology-related env vars, don't set any
|
||||
|
Loading…
x
Reference in New Issue
Block a user