mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix imports in Mosaic GPU examples
PiperOrigin-RevId: 629003217
This commit is contained in:
parent
47fdc7b08f
commit
8741ab2f25
@ -23,14 +23,14 @@ from jax.experimental.mosaic import gpu as mosaic_gpu
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.mosaic.gpu.dsl import * # noqa: F403
|
||||
import jax.numpy as jnp
|
||||
from mlir import ir
|
||||
from mlir.dialects import arith
|
||||
from mlir.dialects import gpu
|
||||
from mlir.dialects import memref
|
||||
from mlir.dialects import nvgpu
|
||||
from mlir.dialects import nvvm
|
||||
from mlir.dialects import scf
|
||||
from mlir.dialects import vector
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvgpu
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import scf
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
Loading…
x
Reference in New Issue
Block a user