1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-21 06:16:06 +00:00

1 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
c7ed1bd3a8 Add version check to jaxlib plugin imports.
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try and load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.

There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones.

PiperOrigin-RevId: 731808733
2025-02-27 11:52:17 -08:00