[Mosaic GPU] Allow __init__.py to run without _src.lib.mosaic_gpu being available.

PiperOrigin-RevId: 646124431
This commit is contained in:
jax authors 2024-06-24 09:35:51 -07:00 committed by jax authors
parent e0b2144000
commit e119fe933b

View File

@ -31,7 +31,6 @@ from jax._src import config
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
@ -68,8 +67,18 @@ TMA_DESCRIPTOR_ALIGNMENT = 64
c = mgpu.c # This is too common to fully qualify.
RUNTIME_PATH = pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent / "libmosaic_gpu_runtime.so"
if RUNTIME_PATH.exists():
RUNTIME_PATH = None
try:
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
RUNTIME_PATH = (
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent
/ "libmosaic_gpu_runtime.so"
)
except ImportError:
pass
if RUNTIME_PATH and RUNTIME_PATH.exists():
# Set this so that the custom call can find it
os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)