From e119fe933b66c09625acde6e6db85e859cf21d5a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Jun 2024 09:35:51 -0700 Subject: [PATCH] [Mosaic GPU] Allow __init__.py to run without _src.lib.mosaic_gpu being available. PiperOrigin-RevId: 646124431 --- jax/experimental/mosaic/gpu/__init__.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 3c3445c07..7df477ab6 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -31,7 +31,6 @@ from jax._src import config from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.lib import xla_client -from jax._src.lib import mosaic_gpu as mosaic_gpu_lib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin @@ -68,8 +67,18 @@ TMA_DESCRIPTOR_ALIGNMENT = 64 c = mgpu.c # This is too common to fully qualify. -RUNTIME_PATH = pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent / "libmosaic_gpu_runtime.so" -if RUNTIME_PATH.exists(): +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): # Set this so that the custom call can find it os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)