mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
remove several more symbols from jax.core
* `DBIdx` * `DConcreteArray` * `DimensionHandler` * `DuplicateAxisNameError` PiperOrigin-RevId: 510503517
This commit is contained in:
parent
1248383967
commit
6b4de4f91c
@ -26,13 +26,9 @@ from jax._src.core import (
|
||||
ClosedJaxpr as ClosedJaxpr,
|
||||
ConcreteArray as ConcreteArray,
|
||||
ConcretizationTypeError as ConcretizationTypeError,
|
||||
DBIdx as DBIdx,
|
||||
DConcreteArray as DConcreteArray,
|
||||
DShapedArray as DShapedArray,
|
||||
DimSize as DimSize,
|
||||
DimensionHandler as DimensionHandler,
|
||||
DropVar as DropVar,
|
||||
DuplicateAxisNameError as DuplicateAxisNameError,
|
||||
Effect as Effect,
|
||||
Effects as Effects,
|
||||
EvalTrace as EvalTrace,
|
||||
|
@ -27,20 +27,20 @@ import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
tree_leaves)
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
|
||||
|
Loading…
x
Reference in New Issue
Block a user