diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 82fe9c2ba..bf6899fbe 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -28,7 +28,6 @@ from jax._src import core as jax_core from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call import jax.numpy as jnp import numpy as np