mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Export the Shard
type.
PiperOrigin-RevId: 511615655
This commit is contained in:
parent
661c9e14c0
commit
893d359933
@ -168,6 +168,9 @@ from jax import stages as stages
|
||||
from jax import tree_util as tree_util
|
||||
from jax import util as util
|
||||
|
||||
# Also circular dependency.
|
||||
from jax._src.array import Shard as Shard
|
||||
|
||||
import jax.lib # TODO(phawkins): remove this export.
|
||||
|
||||
if hasattr(jax, '_src'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user