Export the Shard type.

PiperOrigin-RevId: 511615655
This commit is contained in:
Brennan Saeta 2023-02-22 15:37:25 -08:00 committed by jax authors
parent 661c9e14c0
commit 893d359933

View File

@ -168,6 +168,9 @@ from jax import stages as stages
from jax import tree_util as tree_util
from jax import util as util
# Also circular dependency.
from jax._src.array import Shard as Shard
import jax.lib # TODO(phawkins): remove this export.
if hasattr(jax, '_src'):