remove several more symbols from jax.core

* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
This commit is contained in:
Roy Frostig 2023-02-17 13:04:00 -08:00 committed by jax authors
parent 1248383967
commit 6b4de4f91c
2 changed files with 4 additions and 8 deletions

View File

@ -26,13 +26,9 @@ from jax._src.core import (
ClosedJaxpr as ClosedJaxpr,
ConcreteArray as ConcreteArray,
ConcretizationTypeError as ConcretizationTypeError,
DBIdx as DBIdx,
DConcreteArray as DConcreteArray,
DShapedArray as DShapedArray,
DimSize as DimSize,
DimensionHandler as DimensionHandler,
DropVar as DropVar,
DuplicateAxisNameError as DuplicateAxisNameError,
Effect as Effect,
Effects as Effects,
EvalTrace as EvalTrace,

View File

@ -27,20 +27,20 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray, DBIdx
from jax.api_util import flatten_fun_nokwargs
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax.config import config
config.parse_flags_with_absl()
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))