mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Allow __init__.py to run without _src.lib.mosaic_gpu being available.
PiperOrigin-RevId: 646124431
This commit is contained in:
parent
e0b2144000
commit
e119fe933b
@ -31,7 +31,6 @@ from jax._src import config
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import builtin
|
||||
@ -68,8 +67,18 @@ TMA_DESCRIPTOR_ALIGNMENT = 64
|
||||
c = mgpu.c # This is too common to fully qualify.
|
||||
|
||||
|
||||
RUNTIME_PATH = pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent / "libmosaic_gpu_runtime.so"
|
||||
if RUNTIME_PATH.exists():
|
||||
RUNTIME_PATH = None
|
||||
try:
|
||||
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
|
||||
|
||||
RUNTIME_PATH = (
|
||||
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent
|
||||
/ "libmosaic_gpu_runtime.so"
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if RUNTIME_PATH and RUNTIME_PATH.exists():
|
||||
# Set this so that the custom call can find it
|
||||
os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user