mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22642 from ROCm:ci_jax_exp
PiperOrigin-RevId: 655894235
This commit is contained in:
commit
e14752c0ab
@ -207,7 +207,7 @@ _plugin_callback_lock = threading.Lock()
|
||||
# It is fine for a plugin not to implement every feature that JAX uses, provided
|
||||
# that a reasonable feature set is implemented and the plugin fails gracefully
|
||||
# for unimplemented features. Wrong outputs are not acceptable.
|
||||
_nonexperimental_plugins: set[str] = {'cuda'}
|
||||
_nonexperimental_plugins: set[str] = {'cuda', 'rocm'}
|
||||
|
||||
def register_backend_factory(name: str, factory: BackendFactory, *,
|
||||
priority: int = 0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user