Fix imports in Mosaic GPU examples

PiperOrigin-RevId: 629003217
This commit is contained in:
Adam Paszke 2024-04-29 02:27:40 -07:00 committed by jax authors
parent 47fdc7b08f
commit 8741ab2f25

View File

@ -23,14 +23,14 @@ from jax.experimental.mosaic import gpu as mosaic_gpu
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu.dsl import * # noqa: F403
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import scf
from mlir.dialects import vector
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvgpu
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import scf
from jaxlib.mlir.dialects import vector
import numpy as np
# mypy: ignore-errors