Merge pull request #22642 from ROCm:ci_jax_exp

PiperOrigin-RevId: 655894235
This commit is contained in:
jax authors 2024-07-25 03:36:36 -07:00
commit e14752c0ab

View File

@ -207,7 +207,7 @@ _plugin_callback_lock = threading.Lock()
# It is fine for a plugin not to implement every feature that JAX uses, provided
# that a reasonable feature set is implemented and the plugin fails gracefully
# for unimplemented features. Wrong outputs are not acceptable.
_nonexperimental_plugins: set[str] = {'cuda'}
_nonexperimental_plugins: set[str] = {'cuda', 'rocm'}
def register_backend_factory(name: str, factory: BackendFactory, *,
priority: int = 0,