Fix circular import in pallas core file

This commit is contained in:
Ruturaj4 2025-04-09 19:36:15 -05:00
parent ac21549df0
commit ce7347a52b

View File

@ -28,7 +28,6 @@ from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import util
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
import jax.numpy as jnp
import numpy as np