mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
migrate more internal dependencies from jax.core
to jax._src.core
PiperOrigin-RevId: 509736368
This commit is contained in:
parent
b476661b4a
commit
cb8dcce2fe
@ -35,6 +35,10 @@ del _cloud_tpu_init
|
||||
from jax import config as _config_module
|
||||
del _config_module
|
||||
|
||||
# Force early import, allowing use of `jax.core` after importing `jax`.
|
||||
import jax.core as _core
|
||||
del _core
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
|
@ -16,16 +16,16 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import debugging
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
import numpy as np
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
|
@ -16,13 +16,14 @@ from functools import partial
|
||||
import enum
|
||||
from typing import Callable, Sequence, Union
|
||||
|
||||
from jax import core
|
||||
import numpy as np
|
||||
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.numpy.util import _promote_dtypes_inexact
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _fill_lanczos_kernel(radius, x):
|
||||
|
@ -1539,7 +1539,7 @@ class PmapComputation(stages.XlaLowering):
|
||||
class UnloadedPmapExecutable:
|
||||
compiled: Any
|
||||
backend: xb.XlaBackend
|
||||
local_input_avals: Sequence[jax.core.AbstractValue]
|
||||
local_input_avals: Sequence[core.AbstractValue]
|
||||
input_shardings: Sequence[sharding_internal.XLACompatibleSharding]
|
||||
local_output_avals: Sequence[ShapedArray]
|
||||
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
|
||||
|
@ -21,10 +21,7 @@ import operator
|
||||
|
||||
from typing import Callable, Sequence, List, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.core import ConcreteArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -32,11 +29,13 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src.core import replace_jaxpr_effects
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import state
|
||||
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.lax import lax
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (safe_map, extend_name_stack, split_list,
|
||||
|
@ -17,9 +17,8 @@ import operator
|
||||
|
||||
from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from jax import core
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
@ -28,14 +27,15 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
|
||||
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
|
||||
split_list, split_dict)
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.lax.control_flow import loops
|
||||
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
|
||||
|
||||
|
@ -23,7 +23,7 @@ import weakref
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
|
@ -17,10 +17,7 @@ from functools import partial
|
||||
import operator
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.core import raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
@ -28,6 +25,9 @@ from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
|
||||
tree_unflatten, treedef_tuple)
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.core import raise_to_shaped
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import split_list, safe_map
|
||||
import numpy as np
|
||||
|
@ -4426,7 +4426,7 @@ def rng_bit_generator(key, shape, dtype=np.uint32,
|
||||
Most users should use `jax.random` instead for a stable and more user
|
||||
friendly API.
|
||||
"""
|
||||
shape = jax.core.canonicalize_shape(shape)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'),
|
||||
np.dtype('uint32'), np.dtype('uint64')}:
|
||||
|
@ -21,7 +21,7 @@ from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
|
||||
from functools import wraps, partial, partialmethod, lru_cache
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax._src import dispatch
|
||||
|
@ -21,14 +21,14 @@ import numpy as np
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import custom_jvp
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import AxisName
|
||||
from jax._src import util
|
||||
from jax._src.ops.special import logsumexp as _logsumexp
|
||||
import jax.numpy as jnp
|
||||
from jax import custom_jvp
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import AxisName
|
||||
from jax._src.ops.special import logsumexp as _logsumexp
|
||||
|
||||
Array = Any
|
||||
|
||||
|
@ -25,9 +25,9 @@ import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import core
|
||||
from jax._src.util import prod
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.util import prod
|
||||
|
||||
KeyArray = random.KeyArray
|
||||
Array = Any
|
||||
|
@ -16,8 +16,8 @@ import abc
|
||||
from typing import Any, Iterable, List, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
import jax._src.numpy.lax_numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src.numpy.util import _promote_dtypes
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
@ -41,16 +41,16 @@ import opt_einsum
|
||||
|
||||
import jax
|
||||
from jax import jit
|
||||
from jax import core
|
||||
from jax import errors
|
||||
from jax import lax
|
||||
from jax.core import ShapedArray, DShapedArray, ConcreteArray
|
||||
from jax.interpreters import pxla
|
||||
from jax.tree_util import tree_leaves, tree_flatten, tree_map
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src.core import ShapedArray, DShapedArray, ConcreteArray
|
||||
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
|
||||
_sort_le_comparator, PrecisionLike)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
@ -17,18 +17,20 @@ from functools import partial
|
||||
import operator
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
import numpy as np
|
||||
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import core
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot,
|
||||
finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros, trim_zeros_tol, true_divide,
|
||||
vander, zeros)
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
|
||||
diag, dot, finfo, full, maximum, ones, outer, roll, sqrt, trim_zeros,
|
||||
trim_zeros_tol, true_divide, vander, zeros)
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
|
||||
from jax._src.numpy.util import (
|
||||
_check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import numpy as np
|
||||
|
||||
|
||||
@jit
|
||||
|
@ -20,15 +20,18 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps
|
||||
from jax._src.numpy.util import (
|
||||
_broadcast_to, _check_arraylike, _complex_elem_type,
|
||||
_promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||
from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod
|
||||
from jax._src.util import (
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod)
|
||||
|
||||
|
||||
_all = builtins.all
|
||||
|
@ -17,6 +17,12 @@ import operator
|
||||
from textwrap import dedent as _dedent
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
@ -26,10 +32,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
from jax._src.numpy.util import _check_arraylike, _wraps
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import prod as _prod
|
||||
from jax import core
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
import numpy as np
|
||||
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
@ -24,16 +24,16 @@ from typing import Any, Callable, Tuple, Union, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src.api import jit, custom_jvp
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import jit, custom_jvp
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import (
|
||||
_asarray, _check_arraylike, _promote_args, _promote_args_inexact,
|
||||
_promote_args_numeric, _promote_dtypes_inexact, _promote_dtypes_numeric,
|
||||
_promote_shapes, _where, _wraps)
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
||||
|
@ -20,14 +20,13 @@ from typing import (
|
||||
)
|
||||
import warnings
|
||||
|
||||
from jax._src.config import config
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src.config import config
|
||||
from jax._src.lax import lax
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src.lax import lax
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
|
||||
|
||||
import numpy as np
|
||||
@ -232,7 +231,7 @@ def _asarray(arr: ArrayLike) -> Array:
|
||||
"""
|
||||
_check_arraylike("_asarray", arr)
|
||||
dtype, weak_type = dtypes._lattice_result_type(arr)
|
||||
return lax_internal._convert_element_type(arr, dtype, weak_type)
|
||||
return lax._convert_element_type(arr, dtype, weak_type)
|
||||
|
||||
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
@ -283,7 +282,7 @@ def _promote_dtypes(*args: ArrayLike) -> List[Array]:
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
|
||||
@ -293,7 +292,7 @@ def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
@ -304,7 +303,7 @@ def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_numeric = dtypes.to_numeric_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_numeric, weak_type)
|
||||
return [lax._convert_element_type(x, to_dtype_numeric, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
@ -315,7 +314,7 @@ def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_complex = dtypes.to_complex_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
|
||||
return [lax._convert_element_type(x, to_dtype_complex, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
@ -426,7 +425,7 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
.format(x, y))
|
||||
if not np.issubdtype(_dtype(condition), np.bool_):
|
||||
condition = lax.ne(condition, lax_internal._zero(condition))
|
||||
condition = lax.ne(condition, lax._zero(condition))
|
||||
x, y = _promote_dtypes(x, y)
|
||||
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
|
||||
try:
|
||||
|
@ -20,9 +20,9 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
@ -19,14 +19,18 @@ from typing import cast, Any, List, Optional, Tuple
|
||||
import numpy as np
|
||||
import scipy.special as osp_special
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax import jit, vmap
|
||||
from jax import lax, core
|
||||
from jax.interpreters import ad
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.lax_numpy import moveaxis, _promote_args_inexact, _promote_dtypes_inexact
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
moveaxis, _promote_args_inexact, _promote_dtypes_inexact)
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.ops import special as ops_special
|
||||
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
|
||||
|
@ -21,6 +21,7 @@ from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
|
||||
FrozenSet, Union, cast)
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import mlir
|
||||
@ -483,7 +484,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
"""
|
||||
# The dtype doesn't matter here. Its only used for creating the
|
||||
# sharding_spec.
|
||||
aval = jax.core.ShapedArray(shape, np.int32)
|
||||
aval = core.ShapedArray(shape, np.int32)
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
|
||||
|
||||
num_ways_sharded = None
|
||||
|
@ -20,14 +20,14 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.state.types import ShapedArrayRef
|
||||
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
|
||||
## JAX utilities
|
||||
|
||||
|
@ -16,19 +16,21 @@ from functools import partial
|
||||
|
||||
from typing import Any, List, Protocol, Tuple, TypeVar, Union
|
||||
|
||||
from jax import core
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax._src import ad_util
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.typing import Array
|
||||
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
|
||||
AccumEffect)
|
||||
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
|
||||
|
||||
|
||||
## General utilities
|
||||
|
||||
|
@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Set, Union
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
|
||||
from jax._src.lax.control_flow import common
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import jax
|
||||
import inspect
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax.experimental import pjit
|
||||
@ -60,7 +60,7 @@ _CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning"
|
||||
|
||||
|
||||
def _to_jax_shape(s):
|
||||
return jax.core.ShapedArray(s.dimensions(), s.numpy_dtype())
|
||||
return core.ShapedArray(s.dimensions(), s.numpy_dtype())
|
||||
|
||||
|
||||
def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
|
||||
|
@ -19,7 +19,7 @@ import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import api_util
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
|
@ -505,7 +505,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence,
|
||||
import warnings
|
||||
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax.config import config
|
||||
from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
|
@ -29,7 +29,6 @@ from typing import Any, Callable, Optional, Sequence, Tuple
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import tree_util
|
||||
@ -40,6 +39,7 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import custom_derivatives
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src import core
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -18,7 +18,7 @@ from functools import partial, wraps
|
||||
import string
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src.lax import slicing as lax_slicing
|
||||
from jax._src import dtypes
|
||||
|
@ -28,7 +28,6 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax import core
|
||||
from jax import custom_derivatives
|
||||
from jax import random
|
||||
from jax import numpy as jnp
|
||||
@ -45,6 +44,7 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
|
@ -59,20 +59,20 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import partial_eval as pe, pxla
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten,)
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import unzip2, weakref_lru_cache
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ import zlib
|
||||
from typing import Any
|
||||
import jax
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
@ -113,7 +114,7 @@ def _handle_array_process_allgather(inp, tiled):
|
||||
if host_np_arr.ndim == 0 or not tiled:
|
||||
host_np_arr = np.expand_dims(host_np_arr, axis=0)
|
||||
|
||||
aval = jax.core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
|
||||
aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
|
||||
global_aval = global_mesh._local_to_global(
|
||||
pxla.get_array_mapping(pspec), aval)
|
||||
|
||||
@ -325,7 +326,7 @@ def host_local_array_to_global_array(local_inputs: Any,
|
||||
))
|
||||
|
||||
global_aval = _local_to_global_aval(
|
||||
jax.core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
|
||||
core.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)
|
||||
|
||||
return array.ArrayImpl(
|
||||
global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
|
||||
|
@ -31,7 +31,7 @@ import operator as op
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import custom_derivatives
|
||||
from jax import lax
|
||||
from jax._src.numpy.util import _promote_dtypes_inexact
|
||||
|
@ -87,7 +87,7 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.custom_derivatives import custom_vjp
|
||||
|
@ -24,9 +24,8 @@ from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax.core import Tracer
|
||||
from jax.sharding import NamedSharding, PartitionSpec, Mesh
|
||||
from jax._src import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
@ -35,6 +34,7 @@ from jax._src import pjit
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.core import Tracer
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, fft, linalg)
|
||||
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
|
||||
@ -502,7 +502,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
|
||||
outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
|
||||
del main, t, in_tracers, ans, out_tracers
|
||||
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
|
||||
_check_names(out_names_thunk(), out_avals)
|
||||
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
|
||||
if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
|
||||
return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
|
||||
core.EvalTrace.process_shard_map = _shard_map_impl
|
||||
@ -579,7 +579,7 @@ class ShardMapTrace(core.Trace):
|
||||
fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit
|
||||
bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr())
|
||||
fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
|
||||
_rep_rules[fake_primitive] = lambda *_, **__: set()
|
||||
_rep_rules[fake_primitive] = lambda *_, **__: set() # pytype: disable=container-type-mismatch
|
||||
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
|
||||
out_vals = [t.val for t in out_tracers_]
|
||||
if self.check:
|
||||
|
@ -16,7 +16,7 @@
|
||||
import abc
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
import jax.numpy as jnp
|
||||
from jax._src import util
|
||||
from jax._src.typing import Array
|
||||
|
@ -16,7 +16,7 @@ import itertools
|
||||
from typing import Any, Callable, Sequence, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import tree_util
|
||||
from jax._src.api_util import _ensure_index, _ensure_index_tuple
|
||||
from jax.util import safe_zip
|
||||
|
@ -34,7 +34,6 @@ import operator
|
||||
from typing import Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.bcoo import BCOO
|
||||
@ -44,6 +43,7 @@ from jax.experimental.sparse.csr import CSR, CSC
|
||||
from jax.experimental.sparse.util import _coo_extract
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
|
@ -24,7 +24,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
@ -40,6 +39,7 @@ from jax._src.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
from jax.util import safe_zip, unzip2, split_list
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax.lax import (
|
||||
|
@ -25,7 +25,6 @@ import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import config
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
@ -37,6 +36,7 @@ from jax.experimental.sparse.util import (
|
||||
from jax.util import split_list, safe_zip
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -22,12 +22,12 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import _const
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -21,13 +21,13 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
|
||||
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import _const
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
@ -20,10 +20,10 @@ import functools
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import core
|
||||
from jax._src.lib import gpu_solver
|
||||
|
||||
import numpy as np
|
||||
|
@ -54,8 +54,8 @@ from typing import (
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
|
||||
|
@ -19,10 +19,10 @@ from typing import Any, NamedTuple, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import stages
|
||||
from jax._src.api_util import flatten_axes
|
||||
|
@ -46,10 +46,10 @@ import concurrent.futures
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
|
||||
from jax import core, lax
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax import custom_batching
|
||||
from jax._src import api, dtypes, dispatch, lib, api_util
|
||||
from jax.core import Primitive
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
@ -965,7 +965,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertEqual(
|
||||
obj.in_avals,
|
||||
((jax.core.ShapedArray([], expected_dtype, weak_type=True),), {}))
|
||||
((core.ShapedArray([], expected_dtype, weak_type=True),), {}))
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
|
||||
def test_jit_lower_duck_typing(self):
|
||||
@ -1449,7 +1449,7 @@ class APITest(jtu.JaxTestCase):
|
||||
r"The __index__\(\) method was called on the JAX Tracer object.*", lambda: jit(f)(0))
|
||||
|
||||
def test_unimplemented_interpreter_rules(self):
|
||||
foo_p = Primitive('foo')
|
||||
foo_p = core.Primitive('foo')
|
||||
def foo(x):
|
||||
return foo_p.bind(x)
|
||||
|
||||
@ -3543,12 +3543,12 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_jit_returning_token(self):
|
||||
x = jax.jit(jax.lax.create_token)(1.0)
|
||||
self.assertIsInstance(x, jax.core.Token)
|
||||
self.assertIsInstance(x, core.Token)
|
||||
|
||||
def test_jit_capturing_token(self):
|
||||
tok = jax.core.token
|
||||
tok = core.token
|
||||
_, y = jax.jit(lambda x: (x + 2, tok))(7)
|
||||
self.assertIsInstance(y, jax.core.Token)
|
||||
self.assertIsInstance(y, core.Token)
|
||||
|
||||
def test_leak_checker_catches_a_jit_leak(self):
|
||||
with jax.checking_leaks():
|
||||
@ -4119,7 +4119,7 @@ class APITest(jtu.JaxTestCase):
|
||||
return g(x)
|
||||
|
||||
jaxpr = jax.make_jaxpr(h)(7)
|
||||
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
|
||||
core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
|
||||
|
||||
b(8) # don't crash
|
||||
|
||||
@ -4142,7 +4142,7 @@ class APITest(jtu.JaxTestCase):
|
||||
return g(x)
|
||||
|
||||
jaxpr = jax.make_jaxpr(h)(7)
|
||||
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
|
||||
core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)
|
||||
|
||||
b(8) # don't crash
|
||||
|
||||
@ -4734,7 +4734,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
])
|
||||
def test_remat_eval_counter(self, remat):
|
||||
# https://github.com/google/jax/issues/2737
|
||||
add_one_p = Primitive('add_one')
|
||||
add_one_p = core.Primitive('add_one')
|
||||
add_one = add_one_p.bind
|
||||
|
||||
num_evals = 0
|
||||
@ -4772,7 +4772,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
@jax_util.curry
|
||||
def call(f, *args):
|
||||
return jax.core.call(
|
||||
return core.call(
|
||||
lu.wrap_init(lambda *args: [f(*args)]),
|
||||
*args, name='foo')[0]
|
||||
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -145,7 +146,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
|
||||
for i, s in enumerate(arr.addressable_shards):
|
||||
self.assertEqual(s.data.aval,
|
||||
jax.core.ShapedArray(expected_shard_shape, s.data.dtype))
|
||||
core.ShapedArray(expected_shard_shape, s.data.dtype))
|
||||
self.assertArraysEqual(s.data, global_input_data[s.index])
|
||||
self.assertArraysEqual(s.data, arr.addressable_data(i))
|
||||
|
||||
@ -318,13 +319,13 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
r'Expected 8 per-device arrays \(this is how many devices are addressable '
|
||||
r'by the sharding\), but got 4'):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:4], committed=True)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'Expected 8 per-device arrays \(this is how many devices are addressable '
|
||||
r'by the sharding\), but got 16'):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs + bufs, committed=True)
|
||||
|
||||
def test_arrays_not_in_device_assignment(self):
|
||||
if jax.device_count() < 4:
|
||||
@ -342,7 +343,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
"Sharding contains devices {0, 1} that are not present in per-device "
|
||||
"arrays. Per-device arrays contain devices {2, 3} that are not present "
|
||||
"in the sharding."):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_more_devices_in_sharding_than_arrays(self):
|
||||
shape = (8, 2)
|
||||
@ -357,7 +358,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
"Addressable devices and per-device arrays devices do not match. "
|
||||
r"Sharding contains devices \{1\} that are not present in per-device "
|
||||
"arrays."):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_different_devices_in_arrays_than_sharding(self):
|
||||
if jax.device_count() < 3:
|
||||
@ -375,7 +376,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
r"Sharding contains devices \{2\} that are not present in per-device "
|
||||
r"arrays. Per-device arrays contain devices \{0\} that are not present "
|
||||
"in the sharding."):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y"), (2, 2)),
|
||||
@ -410,7 +411,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
"Input buffers to `Array` must have matching dtypes. "
|
||||
"Got int32, expected float32"):
|
||||
array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
def test_array_iter_pmap_sharding(self):
|
||||
if jax.device_count() < 2:
|
||||
@ -975,7 +976,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_pickle_pjit_lower(self):
|
||||
example_exe = jax.jit(lambda x: x * x).lower(
|
||||
jax.core.ShapedArray(
|
||||
core.ShapedArray(
|
||||
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
|
||||
|
||||
# Skip if CompileOptions is not available. This is true on
|
||||
@ -995,7 +996,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
fun,
|
||||
in_axis_resources=P('data'),
|
||||
out_axis_resources=P(None, 'data'),
|
||||
).lower(jax.core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = compile_and_serialize(lowered)
|
||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax import lax
|
||||
@ -1178,7 +1179,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
f = vmap(jax.grad(lambda x: -lax.psum(x, 'i')), out_axes=None, axis_name='i')
|
||||
self.assertEqual(
|
||||
f(a),
|
||||
jax.core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0])
|
||||
core.jaxpr_as_fun(jax.make_jaxpr(f)(a))(a)[0])
|
||||
|
||||
def testAllGatherToUnmapped(self):
|
||||
def f(x):
|
||||
@ -1301,7 +1302,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
Array = Any
|
||||
ArrayElt = Any
|
||||
Int = Union[int, jax.core.Tracer]
|
||||
Int = Union[int, core.Tracer]
|
||||
|
||||
# Can't used NamedTuple here b/c those are pytrees
|
||||
class NamedArray:
|
||||
|
@ -28,6 +28,7 @@ from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
from jax._src.sharding import NamedSharding
|
||||
from jax._src import array
|
||||
from jax._src import core
|
||||
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -1173,7 +1174,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
return x
|
||||
x = jnp.ones(())
|
||||
jaxpr = jax.make_jaxpr(f)(x)
|
||||
roundtrip_f = partial(jax.core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
|
||||
roundtrip_f = partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
|
||||
checked_f = checkify.checkify(jax.jit(roundtrip_f))
|
||||
err, _ = checked_f(jnp.ones(()))
|
||||
self.assertIsNotNone(err.get())
|
||||
|
@ -24,10 +24,8 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import linear_util as lu
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
@ -35,6 +33,8 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
@ -19,9 +19,10 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import core, grad, jit, vmap, lax
|
||||
import jax.numpy as jnp
|
||||
from jax import grad, jit, vmap, lax
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import prod, safe_zip
|
||||
|
||||
|
@ -27,7 +27,7 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax.config import config
|
||||
from jax import dtypes
|
||||
from jax.experimental import host_callback as hcb
|
||||
|
@ -20,6 +20,7 @@ import jax
|
||||
from jax import lax, numpy as jnp
|
||||
from jax import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax._src import core
|
||||
from jax._src.lib import xla_client
|
||||
import jax._src.test_util as jtu
|
||||
import numpy as np
|
||||
@ -37,9 +38,9 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
(y,), token = lax.infeed(
|
||||
token, shape=(jax.core.ShapedArray((3, 4), jnp.float32),))
|
||||
token, shape=(core.ShapedArray((3, 4), jnp.float32),))
|
||||
(z,), _ = lax.infeed(
|
||||
token, shape=(jax.core.ShapedArray((3, 1, 1), jnp.float32),))
|
||||
token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),))
|
||||
return x + y + z
|
||||
|
||||
x = np.float32(1.5)
|
||||
@ -55,8 +56,8 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
x = np.float32(1.5)
|
||||
y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
|
||||
to_infeed = dict(a=x, b=y)
|
||||
to_infeed_shape = dict(a=jax.core.ShapedArray((), dtype=np.float32),
|
||||
b=jax.core.ShapedArray((3, 4), dtype=np.int16))
|
||||
to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32),
|
||||
b=core.ShapedArray((3, 4), dtype=np.int16))
|
||||
@jax.jit
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
@ -77,7 +78,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
|
||||
token, shape=core.ShapedArray((3, 4), jnp.float32))
|
||||
token = lax.outfeed(token, y + np.float32(1))
|
||||
return x - 1
|
||||
|
||||
@ -97,7 +98,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
def doubler(_, token):
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.core.ShapedArray((3, 4), jnp.float32))
|
||||
token, shape=core.ShapedArray((3, 4), jnp.float32))
|
||||
return lax.outfeed(token, y * np.float32(2))
|
||||
|
||||
@jax.jit
|
||||
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax import dtypes
|
||||
from jax._src import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
@ -60,7 +61,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertFalse(output_buffer.aval.weak_type)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
|
||||
self.assertEqual(output_buffer.aval, core.ShapedArray((), dtype))
|
||||
self.assertEqual(output_buffer.dtype, dtype)
|
||||
|
||||
@parameterized.parameters([jax.device_put, _cpp_device_put])
|
||||
@ -73,7 +74,7 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertFalse(output_buffer.aval.weak_type)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
|
||||
self.assertEqual(output_buffer.aval, core.ShapedArray((3, 4), dtype))
|
||||
self.assertEqual(output_buffer.dtype, dtype)
|
||||
np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
|
||||
dtype=dtype))
|
||||
|
@ -19,7 +19,7 @@ import warnings
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
|
@ -26,7 +26,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import dtypes
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax import lax
|
||||
|
@ -40,6 +40,7 @@ from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.test_util import check_grads
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
@ -2296,9 +2297,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testSearchsortedDtype(self):
|
||||
# Test that for large arrays, int64 indices are used. We test this
|
||||
# via abstract evaluation to avoid allocating a large array in tests.
|
||||
a_int32 = jax.core.ShapedArray((np.iinfo(np.int32).max,), np.float32)
|
||||
a_int64 = jax.core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32)
|
||||
v = jax.core.ShapedArray((), np.float32)
|
||||
a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32)
|
||||
a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32)
|
||||
v = core.ShapedArray((), np.float32)
|
||||
|
||||
out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v)
|
||||
self.assertEqual(out_int32.dtype, np.int32)
|
||||
@ -3322,7 +3323,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if mode == 'raise':
|
||||
msg = ("The error occurred because ravel_multi_index was jit-compiled "
|
||||
"with mode='raise'. Use mode='wrap' or mode='clip' instead.")
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
jax.jit(jnp_fun)(*args_maker())
|
||||
else:
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
@ -3360,7 +3361,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if mode == 'raise':
|
||||
msg = ("The error occurred because jnp.choose was jit-compiled"
|
||||
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
|
||||
jax.jit(jnp_fun)(*args_maker())
|
||||
else:
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
@ -4438,7 +4439,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
# abstract tracer value for jnp.mgrid slice
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError,
|
||||
"slice start of jnp.mgrid"):
|
||||
jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2)
|
||||
|
||||
@ -4479,7 +4480,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
# abstract tracer value for ogrid slice
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError,
|
||||
"slice start of jnp.ogrid"):
|
||||
jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2)
|
||||
|
||||
@ -4506,7 +4507,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
|
||||
jnp.r_["asdfgh",[1,2,3]]
|
||||
# abstract tracer value for r_ slice
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError,
|
||||
"slice start of jnp.r_"):
|
||||
jax.jit(lambda a, b: jnp.r_[a:b])(0, 2)
|
||||
|
||||
@ -4555,7 +4556,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "could not understand directive.*"):
|
||||
jnp.c_["asdfgh",[1,2,3]]
|
||||
# abstract tracer value for c_ slice
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError,
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError,
|
||||
"slice start of jnp.c_"):
|
||||
jax.jit(lambda a, b: jnp.c_[a:b])(0, 2)
|
||||
|
||||
@ -4948,13 +4949,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def testArangeConcretizationError(self):
|
||||
msg = r"It arose in jax.numpy.arange argument `{}`".format
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
|
||||
jax.jit(jnp.arange)(3)
|
||||
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('start')):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')):
|
||||
jax.jit(lambda start: jnp.arange(start, 3))(0)
|
||||
|
||||
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')):
|
||||
jax.jit(lambda stop: jnp.arange(0, stop))(3)
|
||||
|
||||
@jtu.sample_product(dtype=[None] + float_dtypes)
|
||||
|
@ -28,7 +28,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax.test_util import check_grads
|
||||
|
@ -28,6 +28,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import experimental
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import distributed
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
@ -537,7 +538,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
f = pjit.pjit(lambda x, y: (x, y),
|
||||
in_axis_resources=experimental.PartitionSpec("x", "y"),
|
||||
out_axis_resources=experimental.PartitionSpec("x", "y"))
|
||||
inp_aval = jax.core.ShapedArray((8, 2), jnp.int32)
|
||||
inp_aval = core.ShapedArray((8, 2), jnp.int32)
|
||||
# `ShapedArray` is considered global when lowered and compiled.
|
||||
# Hence it can bypass the contiguous mesh restriction.
|
||||
compiled = f.lower(inp_aval, gda1).compile()
|
||||
|
@ -16,7 +16,7 @@ import functools
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src.pjit import pjit
|
||||
from jax._src import linear_util as lu
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
||||
|
||||
import scipy.stats
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax.test_util import check_grads
|
||||
from jax import nn
|
||||
|
@ -28,6 +28,7 @@ import concurrent.futures
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.config import parallel_functions_output_gda, jax_array
|
||||
@ -588,7 +589,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
in_axis_resources=(P('x'), P('y')),
|
||||
out_axis_resources=P('y'))
|
||||
f_jaxpr = jax.make_jaxpr(f)(x, y)
|
||||
f_eval = jax.core.jaxpr_as_fun(f_jaxpr)
|
||||
f_eval = core.jaxpr_as_fun(f_jaxpr)
|
||||
r, = f_eval(x, y)
|
||||
self.assertAllClose(r, x.sum() + jnp.sin(y))
|
||||
|
||||
@ -727,11 +728,11 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
def f_for_jit(x):
|
||||
token = lax.create_token(x)
|
||||
(y,), token = lax.infeed(
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(core.ShapedArray(x.shape, np.float32),))
|
||||
(z,), token = lax.infeed(
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(core.ShapedArray(x.shape, np.float32),))
|
||||
(w,), token = lax.infeed(
|
||||
token, shape=(jax.core.ShapedArray(x.shape, np.float32),))
|
||||
token, shape=(core.ShapedArray(x.shape, np.float32),))
|
||||
|
||||
return x + y + z + w
|
||||
|
||||
@ -761,17 +762,17 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
# A replicated infeed
|
||||
(y,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
shape=(core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(None,))
|
||||
# An infeed sharded on first axis
|
||||
(z,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
shape=(core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(P(nr_devices, 1),))
|
||||
# An infeed sharded on second axis
|
||||
(w,), token = lax.infeed(
|
||||
token,
|
||||
shape=(jax.core.ShapedArray(x.shape, np.float32),),
|
||||
shape=(core.ShapedArray(x.shape, np.float32),),
|
||||
partitions=(P(1, nr_devices),))
|
||||
return x + y + z + w
|
||||
|
||||
@ -855,7 +856,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(lowered.in_avals, compiled.in_avals)
|
||||
self.assertEqual(
|
||||
lowered.in_avals,
|
||||
((jax.core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
|
||||
((core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {}))
|
||||
|
||||
splits = np.split(expected, 4)
|
||||
self.assertAllClose(np.asarray(actual.device_buffers[0]), splits[0],
|
||||
@ -1058,7 +1059,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
return x @ y
|
||||
|
||||
shape = (8, 8)
|
||||
aval = jax.core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
||||
aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
||||
x = jnp.arange(np.prod(shape)).reshape(shape)
|
||||
exe = f.lower(aval, x).compile()
|
||||
self.assertIsInstance(exe, stages.Compiled)
|
||||
@ -1509,7 +1510,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
|
||||
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "GDA sharding does not match the input sharding."):
|
||||
compiled(input_gda)
|
||||
@ -1521,7 +1522,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
g1, _ = create_gda(global_input_shape, global_mesh, P(None,))
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=P(None), out_axis_resources=P('x'))
|
||||
compiled = f.lower(jax.core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled(g1) # no error
|
||||
|
||||
@parallel_functions_output_gda(True)
|
||||
@ -1577,7 +1578,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1604,7 +1605,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1628,7 +1629,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
with ctx(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO, out_axis_resources=AUTO)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
|
||||
different_pspec = (P('y', 'x')
|
||||
@ -1652,7 +1653,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x, y, z: (x, y, z), in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp, inp).compile()
|
||||
self.assertLen(compiled.output_shardings, 3)
|
||||
self.assertLen(compiled.input_shardings[0], 3)
|
||||
@ -1680,7 +1681,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1710,7 +1711,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x, y: (x, y), in_axis_resources=(in_resource, AUTO),
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1735,7 +1736,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
f = pjit(lambda x: x, in_axis_resources=AUTO,
|
||||
out_axis_resources=AUTO)
|
||||
|
||||
inp = jax.core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
@ -1987,7 +1988,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y'))
|
||||
a2, _ = create_array(global_input_shape, global_mesh, P('x'))
|
||||
|
||||
aval = jax.core.ShapedArray(global_input_shape, np.float32)
|
||||
aval = core.ShapedArray(global_input_shape, np.float32)
|
||||
|
||||
with jax_array(True):
|
||||
with global_mesh:
|
||||
@ -2111,7 +2112,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jax_array(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_axis_resources=NamedSharding(global_mesh, P(None,)))
|
||||
compiled = f.lower(jax.core.ShapedArray(input_shape, jnp.float32)).compile()
|
||||
compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
|
||||
compiled(a1) # no error
|
||||
|
||||
@jax_array(True)
|
||||
@ -2237,7 +2238,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
di_map = s.devices_indices_map(shape)
|
||||
bufs = [jax.device_put(inp_data[di_map[d]], d)
|
||||
for d in jax.local_devices()]
|
||||
arr = array.ArrayImpl(jax.core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
arr = array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True)
|
||||
|
||||
f = pjit(lambda x: x, out_axis_resources=s)
|
||||
out = f(arr)
|
||||
@ -2338,7 +2339,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
|
||||
arr = array.ArrayImpl(
|
||||
jax.core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
|
||||
core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
|
||||
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
@ -2861,8 +2862,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def test_pjit_with_mismatched_static_argnames(self):
|
||||
x_is_tracer, y_is_tracer = False, False
|
||||
def f(x, y):
|
||||
assert isinstance(x, jax.core.Tracer) == x_is_tracer
|
||||
assert isinstance(y, jax.core.Tracer) == y_is_tracer
|
||||
assert isinstance(x, core.Tracer) == x_is_tracer
|
||||
assert isinstance(y, core.Tracer) == y_is_tracer
|
||||
return 1
|
||||
|
||||
# If both static_argnums and static_argnames are provided, they are allowed
|
||||
@ -3000,7 +3001,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(cache_info3.hits, cache_info2.hits)
|
||||
|
||||
# AOT test
|
||||
compiled = f.lower(jax.core.ShapedArray(y.shape, y.dtype)).compile()
|
||||
compiled = f.lower(core.ShapedArray(y.shape, y.dtype)).compile()
|
||||
out3 = compiled(y)
|
||||
_check(out3, jax.devices()[1], y)
|
||||
|
||||
@ -3030,7 +3031,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
g_out = g(x)
|
||||
_check(g_out, jax.devices()[0], x)
|
||||
|
||||
compiled = g.lower(jax.core.ShapedArray(x.shape, x.dtype)).compile()
|
||||
compiled = g.lower(core.ShapedArray(x.shape, x.dtype)).compile()
|
||||
out4 = compiled(x)
|
||||
_check(out4, jax.devices()[0], x)
|
||||
|
||||
@ -3703,7 +3704,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))
|
||||
|
||||
dims = 5
|
||||
aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)
|
||||
aval = core.ShapedArray((len(devices),) * dims, jnp.float32)
|
||||
def roundtrip(spec):
|
||||
op_sharding = NamedSharding(mesh, spec)._to_xla_op_sharding(aval.ndim)
|
||||
parsed_spec = pjit_lib.parse_flatten_op_sharding(op_sharding, mesh)[0].partitions
|
||||
@ -3732,9 +3733,9 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
def test_get_input_metadata_fully_replicated(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32)
|
||||
global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32)
|
||||
global_in_aval3 = jax.core.ShapedArray((), jnp.int32)
|
||||
global_in_aval1 = core.ShapedArray((4, 4), jnp.int32)
|
||||
global_in_aval2 = core.ShapedArray((4, 4, 4), jnp.int32)
|
||||
global_in_aval3 = core.ShapedArray((), jnp.int32)
|
||||
in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
|
||||
|
||||
mp = NamedSharding(global_mesh, P(None))
|
||||
@ -3753,7 +3754,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
def test_mesh_sharding_spec(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
array_mapping = pxla.get_array_mapping(P('x', 'y'))
|
||||
aval = jax.core.ShapedArray((1, 1), jnp.int32)
|
||||
aval = core.ShapedArray((1, 1), jnp.int32)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '
|
||||
|
@ -37,7 +37,8 @@ from jax import lax
|
||||
from jax._src.lax import parallel
|
||||
from jax._src import api as src_api
|
||||
from jax import random
|
||||
from jax.core import ShapedArray
|
||||
from jax._src import core
|
||||
from jax._src.core import ShapedArray
|
||||
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax._src import config as jax_config
|
||||
@ -204,7 +205,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertFalse(obj._no_kwargs)
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
self.assertEqual(obj.in_avals, ((jax.core.ShapedArray(x.shape, x.dtype),), {}))
|
||||
self.assertEqual(obj.in_avals, ((core.ShapedArray(x.shape, x.dtype),), {}))
|
||||
|
||||
def testLowerCompileInTreeMismatch(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
@ -334,7 +335,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
x_shape = jax.core.ShapedArray(x.shape, x.dtype)
|
||||
x_shape = core.ShapedArray(x.shape, x.dtype)
|
||||
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
|
||||
|
||||
def testMean(self):
|
||||
@ -2012,7 +2013,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def test_axis_env_length(self):
|
||||
f = lambda x: jax.pmap(g)(jnp.array([x]))[0]
|
||||
def g(x):
|
||||
assert len(jax.core.thread_local_state.trace_state.axis_env) == 1
|
||||
assert len(core.thread_local_state.trace_state.axis_env) == 1
|
||||
return x
|
||||
jax.grad(f)(3.) # doesn't fail
|
||||
|
||||
|
@ -19,9 +19,9 @@ from typing import Any, Callable, Sequence
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import debugging
|
||||
from jax._src import dispatch
|
||||
from jax._src import sharding
|
||||
|
@ -27,12 +27,12 @@ import scipy.special
|
||||
import scipy.stats
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import grad
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import prng
|
||||
from jax import random
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
|
@ -23,6 +23,7 @@ from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.experimental.pjit import PartitionSpec as P
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge
|
||||
import jax.numpy as jnp
|
||||
@ -414,7 +415,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_check_rep_false_doesnt_hit_rep_rules(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
|
||||
|
||||
prim = jax.core.Primitive('prim') # no rep rule here!
|
||||
prim = core.Primitive('prim') # no rep rule here!
|
||||
prim.multiple_results = True
|
||||
prim.def_impl(lambda: [])
|
||||
prim.def_abstract_eval(lambda: [])
|
||||
|
@ -20,7 +20,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
|
@ -20,7 +20,7 @@ so it should be checked with pytype/mypy as well as being run with pytest.
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import typing
|
||||
|
@ -31,8 +31,8 @@ import jax.scipy as jscipy
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import NamedShape
|
||||
from jax._src import core
|
||||
from jax._src.core import NamedShape
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import global_device_array
|
||||
from jax._src import array
|
||||
|
Loading…
x
Reference in New Issue
Block a user