migrate more internal dependencies from jax.core to jax._src.core

PiperOrigin-RevId: 509736368
Roy Frostig 2023-02-14 23:00:40 -08:00 committed by jax authors
parent b476661b4a
commit cb8dcce2fe
71 changed files with 230 additions and 201 deletions
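The change is mechanical and repeats across every hunk below: imports of the public alias `jax.core` become imports of the internal module `jax._src.core`, and qualified uses such as `jax.core.ShapedArray` become `core.ShapedArray`. A minimal before/after sketch of the pattern (the surrounding code is illustrative, not taken from the diff):

    import numpy as np

    # before: depend on the public alias
    from jax import core
    aval = core.ShapedArray((8, 2), np.float32)

    # after: depend on the internal module directly; at the time of this
    # commit jax.core re-exports these names, so behavior is unchanged
    from jax._src import core
    aval = core.ShapedArray((8, 2), np.float32)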

@@ -35,6 +35,10 @@ del _cloud_tpu_init
from jax import config as _config_module
del _config_module
# Force early import, allowing use of `jax.core` after importing `jax`.
import jax.core as _core
del _core
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
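This hunk leans on a Python import-system guarantee: importing a submodule binds it as an attribute of its parent package, so the eager `import jax.core as _core` in `jax/__init__.py` keeps `jax.core` reachable immediately after `import jax`. A small sketch of the preserved behavior (assumes an ordinary jax installation of this era):

    import jax  # no explicit `import jax.core` needed afterwards

    aval = jax.core.ShapedArray((2,), 'float32')  # works: jax/__init__.py imported jax.core eagerly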

@@ -16,16 +16,16 @@ from __future__ import annotations
import dataclasses
import inspect
import threading
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
import numpy as np
import jax.numpy as jnp
from jax import core
from jax import tree_util
from jax._src import core
from jax._src import debugging
from jax._src import traceback_util
from jax._src import util
import numpy as np
@tree_util.register_pytree_node_class

@@ -16,13 +16,14 @@ from functools import partial
import enum
from typing import Callable, Sequence, Union
from jax import core
import numpy as np
from jax import jit
from jax import lax
from jax import numpy as jnp
from jax._src import core
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _promote_dtypes_inexact
import numpy as np
def _fill_lanczos_kernel(radius, x):

@@ -1539,7 +1539,7 @@ class PmapComputation(stages.XlaLowering):
class UnloadedPmapExecutable:
compiled: Any
backend: xb.XlaBackend
local_input_avals: Sequence[jax.core.AbstractValue]
local_input_avals: Sequence[core.AbstractValue]
input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
local_output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]

@@ -21,10 +21,7 @@ import operator
from typing import Callable, Sequence, List, Tuple
from jax import core
from jax._src import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir
@@ -32,11 +29,13 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src.core import replace_jaxpr_effects
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
from jax._src import state
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,

@@ -17,9 +17,8 @@ import operator
from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from jax import core
import jax.numpy as jnp
from jax import lax
from jax._src import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import batching
@@ -28,14 +27,15 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list, split_dict)
import jax.numpy as jnp
from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr

@@ -23,7 +23,7 @@ import weakref
from jax._src import core
from jax._src import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir

@@ -17,10 +17,7 @@ from functools import partial
import operator
import jax
from jax import core
from jax import lax
from jax._src import linear_util as lu
from jax.core import raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
@@ -28,6 +25,9 @@ from jax.interpreters import xla
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
tree_unflatten, treedef_tuple)
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import raise_to_shaped
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np

@@ -4426,7 +4426,7 @@ def rng_bit_generator(key, shape, dtype=np.uint32,
Most users should use `jax.random` instead for a stable and more user-friendly API.
"""
shape = jax.core.canonicalize_shape(shape)
shape = core.canonicalize_shape(shape)
dtype = dtypes.canonicalize_dtype(dtype)
if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('uint64')}:
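`canonicalize_shape` normalizes whatever the caller passes (lists, NumPy integer scalars, and so on) into a plain tuple of Python ints before the shape reaches the primitive. A hedged sketch of the expected behavior:

    import numpy as np
    from jax._src import core

    shape = core.canonicalize_shape([np.int64(2), 3])  # expected: the tuple (2, 3)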

@@ -21,7 +21,7 @@ from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
from functools import wraps, partial, partialmethod, lru_cache
from jax import numpy as jnp
from jax import core
from jax._src import core
from jax._src import linear_util as lu
from jax import stages
from jax._src import dispatch

@@ -21,14 +21,14 @@ import numpy as np
from typing import Any, Optional, Tuple, Union
import jax
from jax import custom_jvp
from jax._src import dtypes
from jax import lax
from jax import core
from jax.core import AxisName
from jax._src import util
from jax._src.ops.special import logsumexp as _logsumexp
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.ops.special import logsumexp as _logsumexp
Array = Any

@@ -25,9 +25,9 @@ import numpy as np
import jax.numpy as jnp
from jax import lax
from jax import random
from jax import core
from jax._src.util import prod
from jax._src import core
from jax._src import dtypes
from jax._src.util import prod
KeyArray = random.KeyArray
Array = Any

@@ -16,8 +16,8 @@ import abc
from typing import Any, Iterable, List, Tuple, Union
import jax
from jax import core
import jax._src.numpy.lax_numpy as jnp
from jax._src import core
from jax._src.numpy.util import _promote_dtypes
from jax._src.typing import Array, ArrayLike

@@ -41,16 +41,16 @@ import opt_einsum
import jax
from jax import jit
from jax import core
from jax import errors
from jax import lax
from jax.core import ShapedArray, DShapedArray, ConcreteArray
from jax.interpreters import pxla
from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.core import ShapedArray, DShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal

@@ -17,18 +17,20 @@ from functools import partial
import operator
from typing import Optional, Tuple, Union
from jax import core
import numpy as np
from jax import jit
from jax import lax
from jax._src import dtypes
from jax._src import core
from jax._src.numpy.lax_numpy import (
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot,
finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, trim_zeros_tol, true_divide,
vander, zeros)
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
diag, dot, finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros,
trim_zeros_tol, true_divide, vander, zeros)
from jax._src.numpy import linalg
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
from jax._src.numpy.util import (
_check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps)
from jax._src.typing import Array, ArrayLike
import numpy as np
@jit

@@ -20,15 +20,18 @@ import warnings
import numpy as np
from jax import core
from jax import lax
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps
from jax._src.numpy.util import (
_broadcast_to, _check_arraylike, _complex_elem_type,
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod)
_all = builtins.all

@@ -17,6 +17,12 @@ import operator
from textwrap import dedent as _dedent
from typing import Optional, Tuple, Union
import numpy as np
from jax import jit
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import (
@@ -26,10 +32,6 @@ from jax._src.numpy.lax_numpy import (
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod as _prod
from jax import core
from jax import jit
from jax import lax
import numpy as np
_lax_const = lax_internal._const

@@ -24,16 +24,16 @@ from typing import Any, Callable, Tuple, Union, overload
import numpy as np
from jax._src.api import jit, custom_jvp
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src.api import jit, custom_jvp
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
_asarray, _check_arraylike, _promote_args, _promote_args_inexact,
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
_promote_shapes, _where, _wraps)
from jax import core
from jax import lax
_lax_const = lax_internal._const

@@ -20,14 +20,13 @@ from typing import (
)
import warnings
from jax._src.config import config
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src import api
from jax._src import core
from jax._src.config import config
from jax._src.lax import lax
from jax._src.numpy.ndarray import ndarray
from jax._src.util import safe_zip, safe_map
from jax._src import api
from jax import core
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
import numpy as np
@@ -232,7 +231,7 @@ def _asarray(arr: ArrayLike) -> Array:
"""
_check_arraylike("_asarray", arr)
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax_internal._convert_element_type(arr, dtype, weak_type)
return lax._convert_element_type(arr, dtype, weak_type)
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
@@ -283,7 +282,7 @@ def _promote_dtypes(*args: ArrayLike) -> List[Array]:
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
@@ -293,7 +292,7 @@ def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]
@@ -304,7 +303,7 @@ def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_numeric = dtypes.to_numeric_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_numeric, weak_type)
return [lax._convert_element_type(x, to_dtype_numeric, weak_type)
for x in args]
@@ -315,7 +314,7 @@ def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_complex = dtypes.to_complex_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
return [lax._convert_element_type(x, to_dtype_complex, weak_type)
for x in args]
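All of the `_promote_dtypes*` helpers in this file compute a result type on JAX's promotion lattice and then convert each argument with `_convert_element_type`. The lattice is observable through the public API; a small sketch (assumes the default configuration, with 64-bit mode disabled):

    import numpy as np
    import jax.numpy as jnp

    x = jnp.ones(3, np.int32) + jnp.ones(3, np.float16)
    print(x.dtype)  # float16: integer operands promote to the floating dtype under JAX's rules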
@@ -426,7 +425,7 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax_internal._zero(condition))
condition = lax.ne(condition, lax._zero(condition))
x, y = _promote_dtypes(x, y)
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:

@@ -20,9 +20,9 @@ import warnings
import numpy as np
from jax import core
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.lax import lax as lax_internal

@@ -19,14 +19,18 @@ from typing import cast, Any, List, Optional, Tuple
import numpy as np
import scipy.special as osp_special
from jax._src import api
from jax._src import dtypes
from jax import jit, vmap
from jax import lax, core
from jax.interpreters import ad
import jax.numpy as jnp
from jax import jit
from jax import vmap
from jax import lax
from jax.interpreters import ad
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.lax_numpy import moveaxis, _promote_args_inexact, _promote_dtypes_inexact
from jax._src.numpy.lax_numpy import (
moveaxis, _promote_args_inexact, _promote_dtypes_inexact)
from jax._src.numpy.util import _wraps
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl

@@ -21,6 +21,7 @@ from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import jax
from jax._src import core
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax.interpreters import mlir
@@ -483,7 +484,7 @@ class PmapSharding(XLACompatibleSharding):
"""
# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
aval = jax.core.ShapedArray(shape, np.int32)
aval = core.ShapedArray(shape, np.int32)
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
num_ways_sharded = None

@@ -20,14 +20,14 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
import numpy as np
from jax import core
from jax import lax
from jax._src import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map, safe_zip, split_list
from jax._src import core
from jax._src import linear_util as lu
from jax._src.state.types import ShapedArrayRef
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.util import safe_map, safe_zip, split_list
## JAX utilities

@@ -16,19 +16,21 @@ from functools import partial
from typing import Any, List, Protocol, Tuple, TypeVar, Union
from jax import core
import numpy as np
from jax import lax
from jax._src import ad_util
from jax._src import pretty_printer as pp
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.typing import Array
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect)
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
## General utilities

@@ -16,7 +16,7 @@ from __future__ import annotations
from typing import Any, List, Optional, Sequence, Set, Union
from jax import core
from jax._src import core
from jax._src.lib import xla_bridge, xla_client
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
from jax._src.lax.control_flow import common

@@ -14,7 +14,7 @@
import jax
import inspect
from jax import core
from jax._src import core
from jax import tree_util
from jax._src import linear_util as lu
from jax.experimental import pjit
@@ -60,7 +60,7 @@ _CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning"
def _to_jax_shape(s):
return jax.core.ShapedArray(s.dimensions(), s.numpy_dtype())
return core.ShapedArray(s.dimensions(), s.numpy_dtype())
def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):

@@ -19,7 +19,7 @@ import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
import jax
from jax import core
from jax._src import core
from jax._src import dispatch
from jax._src import api_util
from jax._src.lib import xla_bridge as xb

@@ -505,7 +505,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence,
import warnings
from jax._src import api
from jax import core
from jax._src import core
from jax.config import config
from jax import custom_derivatives
from jax._src import dtypes

@@ -29,7 +29,6 @@ from typing import Any, Callable, Optional, Sequence, Tuple
from absl import logging
import jax
from jax import core
from jax import dlpack
from jax import dtypes
from jax import tree_util
@@ -40,6 +39,7 @@ from jax._src import ad_checkpoint
from jax._src import custom_derivatives
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo

@@ -18,7 +18,7 @@ from functools import partial, wraps
import string
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from jax import core
from jax._src import core
from jax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src import dtypes

@@ -28,7 +28,6 @@ import numpy as np
import jax
from jax import lax
from jax import config
from jax import core
from jax import custom_derivatives
from jax import random
from jax import numpy as jnp
@@ -45,6 +44,7 @@ from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu

@@ -59,20 +59,20 @@ from functools import partial
import numpy as np
import jax
from jax import core
from jax import lax
from jax.interpreters import xla
import jax.numpy as jnp
from jax.experimental import pjit
from jax.interpreters import partial_eval as pe, pxla
from jax._src.api_util import shaped_abstractify
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src.lax import lax as lax_internal
from jax._src import linear_util as lu
from jax._src.api_util import shaped_abstractify
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache

@@ -21,6 +21,7 @@ import zlib
from typing import Any
import jax
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax._src import core
from jax._src import dispatch
from jax._src import array
from jax._src import sharding
@@ -113,7 +114,7 @@ def _handle_array_process_allgather(inp, tiled):
if host_np_arr.ndim == 0 or not tiled:
host_np_arr = np.expand_dims(host_np_arr, axis=0)
aval = jax.core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
global_aval = global_mesh._local_to_global(
pxla.get_array_mapping(pspec), aval)
@@ -325,7 +326,7 @@ def host_local_array_to_global_array(local_inputs: Any,
))
global_aval = _local_to_global_aval(
jax.core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
return array.ArrayImpl(
global_aval, jax.sharding.NamedSharding(global_mesh, pspec),

@@ -31,7 +31,7 @@ import operator as op
import jax
import jax.numpy as jnp
from jax import core
from jax._src import core
from jax import custom_derivatives
from jax import lax
from jax._src.numpy.util import _promote_dtypes_inexact

@@ -87,7 +87,7 @@ from typing import Any, Dict, List, Tuple
import jax
import numpy as np
from jax import core
from jax._src import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.custom_derivatives import custom_vjp

@@ -24,9 +24,8 @@ from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
import numpy as np
import jax
from jax import core
from jax.core import Tracer
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import core
from jax._src import ad_util
from jax._src import custom_derivatives
from jax._src import linear_util as lu
@@ -35,6 +34,7 @@ from jax._src import pjit
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg)
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
@@ -502,7 +502,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, ans, out_tracers
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
_check_names(out_names_thunk(), out_avals)
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
core.EvalTrace.process_shard_map = _shard_map_impl
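`core.mapped_aval(size, axis, aval)`, used a few lines above, computes the per-shard abstract value by removing the mapped axis from an aval. A hedged sketch of the expected behavior:

    import numpy as np
    from jax._src import core

    aval = core.ShapedArray((4, 3), np.float32)
    per_shard = core.mapped_aval(4, 0, aval)  # expected: ShapedArray((3,), float32)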
@@ -579,7 +579,7 @@ class ShardMapTrace(core.Trace):
fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit
bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr())
fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
_rep_rules[fake_primitive] = lambda *_, **__: set()
_rep_rules[fake_primitive] = lambda *_, **__: set() # pytype: disable=container-type-mismatch
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
out_vals = [t.val for t in out_tracers_]
if self.check:

@@ -16,7 +16,7 @@
import abc
from typing import Sequence, Tuple
from jax import core
from jax._src import core
import jax.numpy as jnp
from jax._src import util
from jax._src.typing import Array

@@ -16,7 +16,7 @@ import itertools
from typing import Any, Callable, Sequence, Tuple, Union
import jax
from jax import core
from jax._src import core
from jax import tree_util
from jax._src.api_util import _ensure_index, _ensure_index_tuple
from jax.util import safe_zip

@@ -34,7 +34,6 @@ import operator
from typing import Optional, Union
import jax
from jax import core
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.bcoo import BCOO
@@ -44,6 +43,7 @@ from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching

@@ -24,7 +24,6 @@ import warnings
import numpy as np
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax import vmap
@@ -40,6 +39,7 @@ from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src import core
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax.lax import (

@@ -25,7 +25,6 @@ import numpy as np
import jax.numpy as jnp
from jax import config
from jax import core
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
@@ -37,6 +36,7 @@ from jax.experimental.sparse.util import (
from jax.util import split_list, safe_zip
from jax._src import api_util
from jax._src import core
from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import hlo

@@ -22,12 +22,12 @@ import warnings
import numpy as np
from jax import core
from jax import lax
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src import core
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import hlo

@@ -21,13 +21,13 @@ import warnings
import numpy as np
from jax import core
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
from jax import lax
from jax import tree_util
from jax._src import core
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse

@@ -20,10 +20,10 @@ import functools
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src.lib import gpu_solver
import numpy as np

@@ -54,8 +54,8 @@ from typing import (
import numpy as np
import jax
from jax import core
from jax import lax
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse

@@ -19,10 +19,10 @@ from typing import Any, NamedTuple, Tuple, Union
import numpy as np
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax import vmap
from jax._src import core
from jax._src import dtypes
from jax._src import stages
from jax._src.api_util import flatten_axes

@@ -46,10 +46,10 @@ import concurrent.futures
import jax
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, lax
from jax._src import core
from jax import lax
from jax import custom_batching
from jax._src import api, dtypes, dispatch, lib, api_util
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax._src.interpreters import mlir
@@ -965,7 +965,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
for obj in [lowered, compiled]:
self.assertEqual(
obj.in_avals,
((jax.core.ShapedArray([], expected_dtype, weak_type=True),), {}))
((core.ShapedArray([], expected_dtype, weak_type=True),), {}))
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
def test_jit_lower_duck_typing(self):
@@ -1449,7 +1449,7 @@ class APITest(jtu.JaxTestCase):
r"The __index__\(\) method was called on the JAX Tracer object.*", lambda: jit(f)(0))
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')
foo_p = core.Primitive('foo')
def foo(x):
return foo_p.bind(x)
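The test above deliberately leaves `foo_p` without any rules so that binding it raises a helpful error. For contrast, a minimal complete primitive needs at least a concrete impl and an abstract-eval rule (a sketch, not taken from the diff):

    from jax._src import core

    foo_p = core.Primitive('foo')
    foo_p.def_impl(lambda x: x + 1)             # how to evaluate concrete values
    foo_p.def_abstract_eval(lambda aval: aval)  # shape/dtype rule used under tracing

    print(foo_p.bind(1.0))  # 2.0 -- outside any trace, bind dispatches to the impl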
@@ -3543,12 +3543,12 @@ class APITest(jtu.JaxTestCase):
def test_jit_returning_token(self):
x = jax.jit(jax.lax.create_token)(1.0)
self.assertIsInstance(x, jax.core.Token)
self.assertIsInstance(x, core.Token)
def test_jit_capturing_token(self):
tok = jax.core.token
tok = core.token
_, y = jax.jit(lambda x: (x + 2, tok))(7)
self.assertIsInstance(y, jax.core.Token)
self.assertIsInstance(y, core.Token)
def test_leak_checker_catches_a_jit_leak(self):
with jax.checking_leaks():
@@ -4119,7 +4119,7 @@ class APITest(jtu.JaxTestCase):
return g(x)
jaxpr = jax.make_jaxpr(h)(7)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
b(8) # don't crash
@@ -4142,7 +4142,7 @@ class APITest(jtu.JaxTestCase):
return g(x)
jaxpr = jax.make_jaxpr(h)(7)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
b(8) # don't crash
@@ -4734,7 +4734,7 @@ class RematTest(jtu.JaxTestCase):
])
def test_remat_eval_counter(self, remat):
# https://github.com/google/jax/issues/2737
add_one_p = Primitive('add_one')
add_one_p = core.Primitive('add_one')
add_one = add_one_p.bind
num_evals = 0
@@ -4772,7 +4772,7 @@ class RematTest(jtu.JaxTestCase):
@jax_util.curry
def call(f, *args):
return jax.core.call(
return core.call(
lu.wrap_init(lambda *args: [f(*args)]),
*args, name='foo')[0]

@@ -22,6 +22,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
@@ -145,7 +146,7 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
for i, s in enumerate(arr.addressable_shards):
self.assertEqual(s.data.aval,
jax.core.ShapedArray(expected_shard_shape, s.data.dtype))
core.ShapedArray(expected_shard_shape, s.data.dtype))
self.assertArraysEqual(s.data, global_input_data[s.index])
self.assertArraysEqual(s.data, arr.addressable_data(i))
@@ -318,13 +319,13 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 4'):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
with self.assertRaisesRegex(
ValueError,
r'Expected 8 per-device arrays \(this is how many devices are addressable '
r'by the sharding\), but got 16'):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
def test_arrays_not_in_device_assignment(self):
if jax.device_count() < 4:
@@ -342,7 +343,7 @@ class JaxArrayTest(jtu.JaxTestCase):
"Sharding contains devices {0, 1} that are not present in per-device "
"arrays. Per-device arrays contain devices {2, 3} that are not present "
"in the sharding."):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_more_devices_in_sharding_than_arrays(self):
shape = (8, 2)
@@ -357,7 +358,7 @@ class JaxArrayTest(jtu.JaxTestCase):
"Addressable devices and per-device arrays devices do not match. "
r"Sharding contains devices \{1\} that are not present in per-device "
"arrays."):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_different_devices_in_arrays_than_sharding(self):
if jax.device_count() < 3:
@@ -375,7 +376,7 @@ class JaxArrayTest(jtu.JaxTestCase):
r"Sharding contains devices \{2\} that are not present in per-device "
r"arrays. Per-device arrays contain devices \{0\} that are not present "
"in the sharding."):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (2, 2)),
@@ -410,7 +411,7 @@ class JaxArrayTest(jtu.JaxTestCase):
ValueError,
"Input buffers to `Array` must have matching dtypes. "
"Got int32, expected float32"):
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
def test_array_iter_pmap_sharding(self):
if jax.device_count() < 2:
@@ -975,7 +976,7 @@ class RngShardingTest(jtu.JaxTestCase):
def test_pickle_pjit_lower(self):
example_exe = jax.jit(lambda x: x * x).lower(
jax.core.ShapedArray(
core.ShapedArray(
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
# Skip if CompileOptions is not available. This is true on
@@ -995,7 +996,7 @@ class RngShardingTest(jtu.JaxTestCase):
fun,
in_axis_resources=P('data'),
out_axis_resources=P(None, 'data'),
).lower(jax.core.ShapedArray(shape=(8, 8), dtype=np.float32))
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
def verify_serialization(lowered):
serialized, in_tree, out_tree = compile_and_serialize(lowered)

@@ -24,6 +24,7 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import lax
@@ -1178,7 +1179,7 @@ class BatchingTest(jtu.JaxTestCase):
f = vmap(jax.grad(lambda x: -lax.psum(x, 'i')), out_axes=None, axis_name='i')
self.assertEqual(
f(a),
jax.core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0])
core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0])
def testAllGatherToUnmapped(self):
def f(x):
@@ -1301,7 +1302,7 @@ class BatchingTest(jtu.JaxTestCase):
Array = Any
ArrayElt = Any
Int = Union[int, jax.core.Tracer]
Int = Union[int, core.Tracer]
# Can't use NamedTuple here b/c those are pytrees
class NamedArray:

@@ -28,6 +28,7 @@ from jax.experimental import checkify
from jax.experimental import pjit
from jax._src.sharding import NamedSharding
from jax._src import array
from jax._src import core
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
import jax.numpy as jnp
@@ -1173,7 +1174,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
return x
x = jnp.ones(())
jaxpr = jax.make_jaxpr(f)(x)
roundtrip_f = partial(jax.core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
roundtrip_f = partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
checked_f = checkify.checkify(jax.jit(roundtrip_f))
err, _ = checked_f(jnp.ones(()))
self.assertIsNotNone(err.get())
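`core.eval_jaxpr` is the reference interpreter: given a jaxpr, its constants, and flat arguments, it evaluates the equations in order and returns a list of outputs. A small sketch of the round trip this test relies on:

    import jax
    from jax._src import core

    closed = jax.make_jaxpr(lambda x: x * 2)(1.0)
    out = core.eval_jaxpr(closed.jaxpr, closed.consts, 3.0)  # a one-element list: [6.0]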

@@ -24,10 +24,8 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import core
from jax import lax
from jax import numpy as jnp
from jax._src import linear_util as lu
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray, DBIdx
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
@@ -35,6 +33,8 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal

@@ -19,9 +19,10 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import core, grad, jit, vmap, lax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax._src import config as jax_config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import source_info_util
from jax._src import traceback_util

@@ -19,7 +19,7 @@ import unittest
import numpy as np
import jax
from jax import core
from jax._src import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip

@@ -27,7 +27,7 @@ from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax import core
from jax._src import core
from jax.config import config
from jax import dtypes
from jax.experimental import host_callback as hcb

@@ -20,6 +20,7 @@ import jax
from jax import lax, numpy as jnp
from jax import config
from jax.experimental import host_callback as hcb
from jax._src import core
from jax._src.lib import xla_client
import jax._src.test_util as jtu
import numpy as np
@@ -37,9 +38,9 @@ class InfeedTest(jtu.JaxTestCase):
def f(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
token, shape=(jax.core.ShapedArray((3, 4), jnp.float32),))
token, shape=(core.ShapedArray((3, 4), jnp.float32),))
(z,), _ = lax.infeed(
token, shape=(jax.core.ShapedArray((3, 1, 1), jnp.float32),))
token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),))
return x + y + z
x = np.float32(1.5)
@@ -55,8 +56,8 @@ class InfeedTest(jtu.JaxTestCase):
x = np.float32(1.5)
y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
to_infeed = dict(a=x, b=y)
to_infeed_shape = dict(a=jax.core.ShapedArray((), dtype=np.float32),
b=jax.core.ShapedArray((3, 4), dtype=np.int16))
to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32),
b=core.ShapedArray((3, 4), dtype=np.int16))
@jax.jit
def f(x):
token = lax.create_token(x)
@@ -77,7 +78,7 @@ class InfeedTest(jtu.JaxTestCase):
def f(x):
token = lax.create_token(x)
y, token = lax.infeed(
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
token, shape=core.ShapedArray((3, 4), jnp.float32))
token = lax.outfeed(token, y + np.float32(1))
return x - 1
@@ -97,7 +98,7 @@ class InfeedTest(jtu.JaxTestCase):
def doubler(_, token):
y, token = lax.infeed(
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
token, shape=core.ShapedArray((3, 4), jnp.float32))
return lax.outfeed(token, y * np.float32(2))
@jax.jit

@@ -19,6 +19,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api
from jax._src import core
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
@@ -60,7 +61,7 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
self.assertEqual(output_buffer.aval, core.ShapedArray((), dtype))
self.assertEqual(output_buffer.dtype, dtype)
@parameterized.parameters([jax.device_put, _cpp_device_put])
@@ -73,7 +74,7 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
self.assertEqual(output_buffer.aval, core.ShapedArray((3, 4), dtype))
self.assertEqual(output_buffer.dtype, dtype)
np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
dtype=dtype))

@@ -19,7 +19,7 @@ import warnings
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax._src import core
from jax import lax
from jax._src import linear_util as lu
from jax.config import config

@@ -26,7 +26,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import core
from jax._src import core
from jax import dtypes
from jax.errors import UnexpectedTracerError
from jax import lax

@@ -40,6 +40,7 @@ from jax import numpy as jnp
from jax import tree_util
from jax.test_util import check_grads
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import test_util as jtu
@@ -2296,9 +2297,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testSearchsortedDtype(self):
# Test that for large arrays, int64 indices are used. We test this
# via abstract evaluation to avoid allocating a large array in tests.
a_int32 = jax.core.ShapedArray((np.iinfo(np.int32).max,), np.float32)
a_int64 = jax.core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32)
v = jax.core.ShapedArray((), np.float32)
a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32)
a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32)
v = core.ShapedArray((), np.float32)
out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v)
self.assertEqual(out_int32.dtype, np.int32)
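`jax.eval_shape` runs only the abstract interpretation, so avals describing enormous arrays cost nothing to evaluate; that is the trick this test relies on. A minimal sketch:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax._src import core

    a = core.ShapedArray((10**9,), np.float32)  # nothing this size is ever allocated
    v = core.ShapedArray((), np.float32)
    out = jax.eval_shape(jnp.searchsorted, a, v)
    print(out.shape, out.dtype)  # expected: () int32 under the default config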
@@ -3322,7 +3323,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if mode == 'raise':
msg = ("The error occurred because ravel_multi_index was jit-compiled "
"with mode='raise'. Use mode='wrap' or mode='clip' instead.")
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(jnp_fun)(*args_maker())
else:
self._CompileAndCheck(jnp_fun, args_maker)
@@ -3360,7 +3361,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if mode == 'raise':
msg = ("The error occurred because jnp.choose was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(jnp_fun)(*args_maker())
else:
self._CompileAndCheck(jnp_fun, args_maker)
@@ -4438,7 +4439,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
atol=atol,
rtol=rtol)
# abstract tracer value for jnp.mgrid slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.mgrid"):
jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
@@ -4479,7 +4480,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
atol=atol,
rtol=rtol)
# abstract tracer value for ogrid slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.ogrid"):
jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)
@@ -4506,7 +4507,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.r_["asdfgh",[1,2,3]]
# abstract tracer value for r_ slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.r_"):
jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)
@@ -4555,7 +4556,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
jnp.c_["asdfgh",[1,2,3]]
# abstract tracer value for c_ slice
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
with self.assertRaisesRegex(core.ConcretizationTypeError,
"slice start of jnp.c_"):
jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)
@@ -4948,13 +4949,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testArangeConcretizationError(self):
msg = r"It arose in jax.numpy.arange argument `{}`".format
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
jax.jit(jnp.arange)(3)
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('start')):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')):
jax.jit(lambda start: jnp.arange(start, 3))(0)
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
jax.jit(lambda stop: jnp.arange(0, stop))(3)
@jtu.sample_product(dtype=[None] + float_dtypes)

@@ -28,7 +28,7 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax import core
from jax._src import core
from jax import lax
import jax.numpy as jnp
from jax.test_util import check_grads

@@ -28,6 +28,7 @@ import numpy as np
import jax
from jax import experimental
from jax.config import config
from jax._src import core
from jax._src import distributed
import jax.numpy as jnp
from jax._src import test_util as jtu
@@ -537,7 +538,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=experimental.PartitionSpec("x", "y"),
out_axis_resources=experimental.PartitionSpec("x", "y"))
inp_aval = jax.core.ShapedArray((8, 2), jnp.int32)
inp_aval = core.ShapedArray((8, 2), jnp.int32)
# `ShapedArray` is considered global when lowered and compiled.
# Hence it can bypass the contiguous mesh restriction.
compiled = f.lower(inp_aval, gda1).compile()

@@ -16,7 +16,7 @@ import functools
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax._src import core
from jax import lax
from jax._src.pjit import pjit
from jax._src import linear_util as lu

@@ -23,7 +23,7 @@ from absl.testing import parameterized
import scipy.stats
from jax import core
from jax._src import core
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax import nn

@@ -28,6 +28,7 @@ import concurrent.futures
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.config import parallel_functions_output_gda, jax_array
@@ -588,7 +589,7 @@ class PJitTest(jtu.BufferDonationTestCase):
in_axis_resources=(P('x'), P('y')),
out_axis_resources=P('y'))
f_jaxpr = jax.make_jaxpr(f)(x, y)
f_eval = jax.core.jaxpr_as_fun(f_jaxpr)
f_eval = core.jaxpr_as_fun(f_jaxpr)
r, = f_eval(x, y)
self.assertAllClose(r, x.sum() + jnp.sin(y))
@@ -727,11 +728,11 @@ class PJitTest(jtu.BufferDonationTestCase):
def f_for_jit(x):
token = lax.create_token(x)
(y,), token = lax.infeed(
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
token, shape=(core.ShapedArray(x.shape, np.float32),))
(z,), token = lax.infeed(
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
token, shape=(core.ShapedArray(x.shape, np.float32),))
(w,), token = lax.infeed(
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
token, shape=(core.ShapedArray(x.shape, np.float32),))
return x + y + z + w
@@ -761,17 +762,17 @@ class PJitTest(jtu.BufferDonationTestCase):
# A replicated infeed
(y,), token = lax.infeed(
token,
shape=(jax.core.ShapedArray(x.shape, np.float32),),
shape=(core.ShapedArray(x.shape, np.float32),),
partitions=(None,))
# An infeed sharded on first axis
(z,), token = lax.infeed(
token,
shape=(jax.core.ShapedArray(x.shape, np.float32),),
shape=(core.ShapedArray(x.shape, np.float32),),
partitions=(P(nr_devices, 1),))
# An infeed sharded on second axis
(w,), token = lax.infeed(
token,
shape=(jax.core.ShapedArray(x.shape, np.float32),),
shape=(core.ShapedArray(x.shape, np.float32),),
partitions=(P(1, nr_devices),))
return x + y + z + w
@@ -855,7 +856,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertEqual(lowered.in_avals, compiled.in_avals)
self.assertEqual(
lowered.in_avals,
((jax.core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
((core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
splits = np.split(expected, 4)
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
@@ -1058,7 +1059,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
aval = jax.core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(aval, x).compile()
self.assertIsInstance(exe, stages.Compiled)
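`lower` accepts abstract `ShapedArray` stand-ins in place of (or mixed with) concrete arrays, which is what makes this ahead-of-time path testable without real data. A minimal sketch:

    import numpy as np
    import jax
    from jax._src import core

    f = jax.jit(lambda x: x * 2)
    aval = core.ShapedArray((8, 8), np.float32)   # abstract stand-in; no buffer allocated
    compiled = f.lower(aval).compile()            # ahead-of-time lowering and compilation
    out = compiled(np.ones((8, 8), np.float32))   # supply a real array at call time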
@@ -1509,7 +1510,7 @@ class GDAPjitTest(jtu.JaxTestCase):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
with self.assertRaisesRegex(
ValueError, "GDA sharding does not match the input sharding."):
compiled(input_gda)
@@ -1521,7 +1522,7 @@ class GDAPjitTest(jtu.JaxTestCase):
g1, _ = create_gda(global_input_shape, global_mesh, P(None,))
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=P(None), out_axis_resources=P('x'))
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
compiled(g1) # no error
@parallel_functions_output_gda(True)
@@ -1577,7 +1578,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1604,7 +1605,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1628,7 +1629,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
with ctx(True):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=AUTO, out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
different_pspec = (P('y', 'x')
@@ -1652,7 +1653,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
with global_mesh:
f = pjit(lambda x, y, z: (x, y, z), in_axis_resources=AUTO,
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp, inp).compile()
self.assertLen(compiled.output_shardings, 3)
self.assertLen(compiled.input_shardings[0], 3)
@@ -1680,7 +1681,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp).compile()
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1710,7 +1711,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1735,7 +1736,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, in_axis_resources=AUTO,
out_axis_resources=AUTO)
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
@@ -1987,7 +1988,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y'))
a2, _ = create_array(global_input_shape, global_mesh, P('x'))
aval = jax.core.ShapedArray(global_input_shape, np.float32)
aval = core.ShapedArray(global_input_shape, np.float32)
with jax_array(True):
with global_mesh:
@@ -2111,7 +2112,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jax_array(True):
with global_mesh:
f = pjit(lambda x: x, in_axis_resources=NamedSharding(global_mesh, P(None,)))
compiled = f.lower(jax.core.ShapedArray(input_shape, jnp.float32)).compile()
compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error
@jax_array(True)
@@ -2237,7 +2238,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
di_map = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[di_map[d]], d)
for d in jax.local_devices()]
arr = array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
arr = array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
f = pjit(lambda x: x, out_axis_resources=s)
out = f(arr)
@@ -2338,7 +2339,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
arr = array.ArrayImpl(
jax.core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
with self.assertRaisesRegex(
NotImplementedError,
@@ -2861,8 +2862,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_pjit_with_mismatched_static_argnames(self):
x_is_tracer, y_is_tracer = False, False
def f(x, y):
assert isinstance(x, jax.core.Tracer) == x_is_tracer
assert isinstance(y, jax.core.Tracer) == y_is_tracer
assert isinstance(x, core.Tracer) == x_is_tracer
assert isinstance(y, core.Tracer) == y_is_tracer
return 1
# If both static_argnums and static_argnames are provided, they are allowed
@@ -3000,7 +3001,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(cache_info3.hits, cache_info2.hits)
# AOT test
compiled = f.lower(jax.core.ShapedArray(y.shape, y.dtype)).compile()
compiled = f.lower(core.ShapedArray(y.shape, y.dtype)).compile()
out3 = compiled(y)
_check(out3, jax.devices()[1], y)
@@ -3030,7 +3031,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
g_out = g(x)
_check(g_out, jax.devices()[0], x)
compiled = g.lower(jax.core.ShapedArray(x.shape, x.dtype)).compile()
compiled = g.lower(core.ShapedArray(x.shape, x.dtype)).compile()
out4 = compiled(x)
_check(out4, jax.devices()[0], x)
@@ -3703,7 +3704,7 @@ class UtilTest(jtu.JaxTestCase):
mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))
dims = 5
aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)
aval = core.ShapedArray((len(devices),) * dims, jnp.float32)
def roundtrip(spec):
op_sharding = NamedSharding(mesh, spec)._to_xla_op_sharding(aval.ndim)
parsed_spec = pjit_lib.parse_flatten_op_sharding(op_sharding, mesh)[0].partitions
@@ -3732,9 +3733,9 @@ class UtilTest(jtu.JaxTestCase):
def test_get_input_metadata_fully_replicated(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32)
global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32)
global_in_aval3 = jax.core.ShapedArray((), jnp.int32)
global_in_aval1 = core.ShapedArray((4, 4), jnp.int32)
global_in_aval2 = core.ShapedArray((4, 4, 4), jnp.int32)
global_in_aval3 = core.ShapedArray((), jnp.int32)
in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
mp = NamedSharding(global_mesh, P(None))
@@ -3753,7 +3754,7 @@ class UtilTest(jtu.JaxTestCase):
def test_mesh_sharding_spec(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
array_mapping = pxla.get_array_mapping(P('x', 'y'))
aval = jax.core.ShapedArray((1, 1), jnp.int32)
aval = core.ShapedArray((1, 1), jnp.int32)
with self.assertRaisesRegex(
ValueError,
'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '

@@ -37,7 +37,8 @@ from jax import lax
from jax._src.lax import parallel
from jax._src import api as src_api
from jax import random
from jax.core import ShapedArray
from jax._src import core
from jax._src.core import ShapedArray
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
@@ -204,7 +205,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_avals, ((jax.core.ShapedArray(x.shape, x.dtype),), {}))
self.assertEqual(obj.in_avals, ((core.ShapedArray(x.shape, x.dtype),), {}))
def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
@@ -334,7 +335,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x_shape = jax.core.ShapedArray(x.shape, x.dtype)
x_shape = core.ShapedArray(x.shape, x.dtype)
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
def testMean(self):
@@ -2012,7 +2013,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_axis_env_length(self):
f = lambda x: jax.pmap(g)(jnp.array([x]))[0]
def g(x):
assert len(jax.core.thread_local_state.trace_state.axis_env) == 1
assert len(core.thread_local_state.trace_state.axis_env) == 1
return x
jax.grad(f)(3.) # doesn't fail

@@ -19,9 +19,9 @@ from typing import Any, Callable, Sequence
from absl.testing import absltest
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax._src import core
from jax._src import debugging
from jax._src import dispatch
from jax._src import sharding

@@ -27,12 +27,12 @@ import scipy.special
import scipy.stats
import jax
from jax import core
from jax import grad
from jax import lax
from jax import numpy as jnp
from jax import prng
from jax import random
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap

@@ -23,6 +23,7 @@ from jax import lax
from jax.config import config
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
import jax.numpy as jnp
@@ -414,7 +415,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_check_rep_false_doesnt_hit_rep_rules(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
prim = jax.core.Primitive('prim') # no rep rule here!
prim = core.Primitive('prim') # no rep rule here!
prim.multiple_results = True
prim.def_impl(lambda: [])
prim.def_abstract_eval(lambda: [])

@@ -20,7 +20,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import core
from jax._src import core
from jax import lax
from jax._src import linear_util as lu
from jax.config import config

@@ -20,7 +20,7 @@ so it should be checked with pytype/mypy as well as being run with pytest.
from typing import Any, Optional, Union
import jax
from jax import core
from jax._src import core
from jax._src import config as jax_config
from jax._src import test_util as jtu
from jax._src import typing

@@ -31,8 +31,8 @@ import jax.scipy as jscipy
from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape
from jax._src import core
from jax._src.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
from jax._src import array