mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Cleanup: fix unused imports & mark exported names
This commit is contained in:
parent
6a00055980
commit
de3191fab3
@ -81,7 +81,7 @@ del _xc
|
||||
|
||||
from jax._src.api import effects_barrier as effects_barrier
|
||||
from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
||||
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||
from jax._src.api import clear_backends as _deprecated_clear_backends
|
||||
from jax._src.api import clear_caches as clear_caches
|
||||
@ -122,7 +122,7 @@ from jax._src.xla_bridge import process_count as process_count
|
||||
from jax._src.xla_bridge import process_index as process_index
|
||||
from jax._src.xla_bridge import process_indices as process_indices
|
||||
from jax._src.callback import pure_callback as pure_callback
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as remat # noqa: F401
|
||||
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from jax._src.api import value_and_grad as value_and_grad
|
||||
from jax._src.api import vjp as vjp
|
||||
@ -249,6 +249,6 @@ else:
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
||||
import jax.lib # TODO(phawkins): remove this export.
|
||||
import jax.lib # TODO(phawkins): remove this export. # noqa: F401
|
||||
|
||||
# trailer
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cluster import ClusterEnv
|
||||
from .cluster import ClusterEnv as ClusterEnv
|
||||
|
||||
# Order of declaration of the cluster environments
|
||||
# will dictate the order in which they will be checked.
|
||||
@ -20,9 +20,9 @@ from .cluster import ClusterEnv
|
||||
# the user did not explicitly provide the arguments
|
||||
# to :func:`jax.distributed.initialize`, the first
|
||||
# available one from the list will be picked.
|
||||
from .ompi_cluster import OmpiCluster
|
||||
from .slurm_cluster import SlurmCluster
|
||||
from .mpi4py_cluster import Mpi4pyCluster
|
||||
from .cloud_tpu_cluster import GkeTpuCluster
|
||||
from .cloud_tpu_cluster import GceTpuCluster
|
||||
from .k8s_cluster import K8sCluster
|
||||
from .ompi_cluster import OmpiCluster as OmpiCluster
|
||||
from .slurm_cluster import SlurmCluster as SlurmCluster
|
||||
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
|
||||
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
|
||||
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
|
||||
from .k8s_cluster import K8sCluster as K8sCluster
|
||||
|
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .fusion import cudnn_fusion
|
||||
from .fusion import cudnn_fusion as cudnn_fusion
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax._src.debugger.core import breakpoint
|
||||
from jax._src.debugger.core import breakpoint as breakpoint
|
||||
from jax._src.debugger import cli_debugger
|
||||
from jax._src.debugger import colab_debugger
|
||||
from jax._src.debugger import web_debugger
|
||||
|
@ -12,25 +12,48 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for the control flow primitives."""
|
||||
from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
|
||||
cummin, cummin_p, cumlogsumexp,
|
||||
cumlogsumexp_p, cumprod,
|
||||
cumprod_p, cumsum, cumsum_p,
|
||||
cumred_reduce_window_impl,
|
||||
fori_loop, map,
|
||||
scan, scan_bind, scan_p,
|
||||
_scan_impl, while_loop, while_p)
|
||||
from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch,
|
||||
platform_dependent)
|
||||
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
|
||||
_custom_linear_solve_impl,
|
||||
linear_solve_p)
|
||||
|
||||
from jax._src.lax.control_flow.loops import (
|
||||
associative_scan as associative_scan,
|
||||
cummax as cummax,
|
||||
cummax_p as cummax_p,
|
||||
cummin as cummin,
|
||||
cummin_p as cummin_p,
|
||||
cumlogsumexp as cumlogsumexp,
|
||||
cumlogsumexp_p as cumlogsumexp_p,
|
||||
cumprod as cumprod,
|
||||
cumprod_p as cumprod_p,
|
||||
cumsum as cumsum,
|
||||
cumsum_p as cumsum_p,
|
||||
cumred_reduce_window_impl as cumred_reduce_window_impl,
|
||||
fori_loop as fori_loop,
|
||||
map as map,
|
||||
scan as scan,
|
||||
scan_bind as scan_bind,
|
||||
scan_p as scan_p,
|
||||
_scan_impl as _scan_impl,
|
||||
while_loop as while_loop,
|
||||
while_p as while_p,
|
||||
)
|
||||
from jax._src.lax.control_flow.conditionals import (
|
||||
cond as cond,
|
||||
cond_p as cond_p,
|
||||
switch as switch,
|
||||
platform_dependent as platform_dependent,
|
||||
)
|
||||
from jax._src.lax.control_flow.solves import (
|
||||
custom_linear_solve as custom_linear_solve,
|
||||
custom_root as custom_root,
|
||||
_custom_linear_solve_impl as _custom_linear_solve_impl,
|
||||
linear_solve_p as linear_solve_p,
|
||||
)
|
||||
# Private utilities used elsewhere in JAX
|
||||
# TODO(sharadmv): lift them into a more common place
|
||||
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
|
||||
_initial_style_jaxpr,
|
||||
_initial_style_jaxprs_with_common_consts,
|
||||
_check_tree_and_avals)
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_initial_style_open_jaxpr as _initial_style_open_jaxpr,
|
||||
_initial_style_jaxpr as _initial_style_jaxpr,
|
||||
_initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
|
||||
_check_tree_and_avals as _check_tree_and_avals,
|
||||
|
||||
)
|
||||
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
|
||||
from jax._src.lax.lax import optimization_barrier_p
|
||||
from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p
|
||||
|
@ -21,7 +21,6 @@ import gc
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import jaxlib as jaxlib
|
||||
@ -84,9 +83,9 @@ version = check_jaxlib_version(
|
||||
import jaxlib.cpu_feature_guard as cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
import jaxlib.utils as utils
|
||||
import jaxlib.utils as utils # noqa: F401
|
||||
import jaxlib.xla_client as xla_client
|
||||
import jaxlib.lapack as lapack
|
||||
import jaxlib.lapack as lapack # noqa: F401
|
||||
|
||||
xla_extension = xla_client._xla
|
||||
pytree = xla_client._xla.pytree
|
||||
@ -99,18 +98,18 @@ def _xla_gc_callback(*args):
|
||||
gc.callbacks.append(_xla_gc_callback)
|
||||
|
||||
try:
|
||||
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error
|
||||
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error # noqa: F401
|
||||
except ImportError:
|
||||
try:
|
||||
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error
|
||||
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error # noqa: F401
|
||||
except ImportError:
|
||||
cuda_versions = None
|
||||
|
||||
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
|
||||
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
|
||||
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
|
||||
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
|
||||
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
|
||||
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401
|
||||
|
||||
# Jaxlib code is split between the Jax and the Tensorflow repositories.
|
||||
# Only for the internal usage of the JAX developers, we expose a version
|
||||
@ -118,10 +117,10 @@ import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
|
||||
# branch on the Jax github.
|
||||
xla_extension_version: int = getattr(xla_client, '_version', 0)
|
||||
|
||||
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
|
||||
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
|
||||
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
|
||||
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
|
||||
|
||||
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error
|
||||
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401
|
||||
|
||||
# Version number for MLIR:Python APIs, provided by jaxlib.
|
||||
mlir_api_version = xla_client.mlir_api_version
|
||||
|
@ -13,14 +13,14 @@
|
||||
# limitations under the License.
|
||||
"""Module for state."""
|
||||
from jax._src.state.types import (
|
||||
AbstractRef,
|
||||
AccumEffect,
|
||||
ReadEffect,
|
||||
RefEffect,
|
||||
StateEffect,
|
||||
Transform,
|
||||
TransformedRef,
|
||||
WriteEffect,
|
||||
get_ref_state_effects,
|
||||
shaped_array_ref,
|
||||
AbstractRef as AbstractRef,
|
||||
AccumEffect as AccumEffect,
|
||||
ReadEffect as ReadEffect,
|
||||
RefEffect as RefEffect,
|
||||
StateEffect as StateEffect,
|
||||
Transform as Transform,
|
||||
TransformedRef as TransformedRef,
|
||||
WriteEffect as WriteEffect,
|
||||
get_ref_state_effects as get_ref_state_effects,
|
||||
shaped_array_ref as shaped_array_ref,
|
||||
)
|
||||
|
@ -13,14 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.ad_checkpoint import (
|
||||
checkpoint,
|
||||
checkpoint_policies,
|
||||
checkpoint_name,
|
||||
print_saved_residuals,
|
||||
remat,
|
||||
checkpoint as checkpoint,
|
||||
checkpoint_policies as checkpoint_policies,
|
||||
checkpoint_name as checkpoint_name,
|
||||
print_saved_residuals as print_saved_residuals,
|
||||
remat as remat,
|
||||
)
|
||||
from jax._src.interpreters.partial_eval import (
|
||||
Recompute,
|
||||
Saveable,
|
||||
Offloadable,
|
||||
Recompute as Recompute,
|
||||
Saveable as Saveable,
|
||||
Offloadable as Offloadable,
|
||||
)
|
||||
|
@ -14,12 +14,12 @@
|
||||
|
||||
|
||||
from jax._src.api_util import (
|
||||
argnums_partial,
|
||||
donation_vector,
|
||||
flatten_axes,
|
||||
flatten_fun,
|
||||
flatten_fun_nokwargs,
|
||||
rebase_donate_argnums,
|
||||
safe_map,
|
||||
shaped_abstractify,
|
||||
argnums_partial as argnums_partial,
|
||||
donation_vector as donation_vector,
|
||||
flatten_axes as flatten_axes,
|
||||
flatten_fun as flatten_fun,
|
||||
flatten_fun_nokwargs as flatten_fun_nokwargs,
|
||||
rebase_donate_argnums as rebase_donate_argnums,
|
||||
safe_map as safe_map,
|
||||
shaped_abstractify as shaped_abstractify,
|
||||
)
|
||||
|
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.cloud_tpu_init import cloud_tpu_init
|
||||
from jax._src.cloud_tpu_init import cloud_tpu_init as cloud_tpu_init
|
||||
|
@ -42,7 +42,7 @@ from jax._src.core import (
|
||||
Literal as Literal,
|
||||
MainTrace as MainTrace,
|
||||
MapPrimitive as MapPrimitive,
|
||||
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,
|
||||
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
|
||||
OpaqueTraceState as OpaqueTraceState,
|
||||
NameGatheringSubst as NameGatheringSubst,
|
||||
OutDBIdx as OutDBIdx,
|
||||
@ -58,9 +58,9 @@ from jax._src.core import (
|
||||
TraceStack as TraceStack,
|
||||
TraceState as TraceState,
|
||||
Tracer as Tracer,
|
||||
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,
|
||||
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE,
|
||||
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE,
|
||||
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
|
||||
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
|
||||
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
|
||||
UnshapedArray as UnshapedArray,
|
||||
Value as Value,
|
||||
Var as Var,
|
||||
|
@ -16,9 +16,9 @@
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.custom_derivatives import (
|
||||
_initial_style_jaxpr,
|
||||
_sum_tangents,
|
||||
_zeros_like_pytree,
|
||||
_initial_style_jaxpr, # noqa: F401
|
||||
_sum_tangents, # noqa: F401
|
||||
_zeros_like_pytree, # noqa: F401
|
||||
closure_convert as closure_convert,
|
||||
custom_gradient as custom_gradient,
|
||||
custom_jvp as custom_jvp,
|
||||
|
@ -18,10 +18,10 @@
|
||||
from jax._src.dtypes import (
|
||||
bfloat16 as bfloat16,
|
||||
canonicalize_dtype as canonicalize_dtype,
|
||||
finfo, # TODO(phawkins): switch callers to jnp.finfo?
|
||||
finfo, # TODO(phawkins): switch callers to jnp.finfo? # noqa: F401
|
||||
float0 as float0,
|
||||
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
|
||||
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
|
||||
iinfo, # TODO(phawkins): switch callers to jnp.iinfo? # noqa: F401
|
||||
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? # noqa: F401
|
||||
extended as extended,
|
||||
prng_key as prng_key,
|
||||
result_type as result_type,
|
||||
|
@ -29,11 +29,9 @@ from typing import Any, Optional
|
||||
|
||||
import jax
|
||||
from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import distributed
|
||||
from jax._src import sharding
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.layout import Layout, DeviceLocalLayout as DLL
|
||||
from jax._src.layout import Layout
|
||||
from jax._src import typing
|
||||
from jax._src import util
|
||||
from jax._src.lib import xla_extension as xe
|
||||
|
@ -27,7 +27,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import array
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental.array_serialization import serialization
|
||||
|
@ -30,7 +30,6 @@ from typing import Any
|
||||
import warnings
|
||||
from absl import flags
|
||||
|
||||
import flax
|
||||
from flax import linen as nn
|
||||
|
||||
import jax
|
||||
|
@ -39,8 +39,8 @@ from absl import logging
|
||||
import numpy.random as npr
|
||||
|
||||
import jax # Must import before TF
|
||||
from jax.experimental import jax2tf # Defines needed flags
|
||||
from jax._src import test_util # Defines needed flags
|
||||
from jax.experimental import jax2tf # Defines needed flags # noqa: F401
|
||||
from jax._src import test_util # Defines needed flags # noqa: F401
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
@ -28,7 +28,6 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
|
@ -67,7 +67,6 @@ from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import jax2tf
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -45,7 +45,6 @@ from jax._src import util
|
||||
from jax._src.export import shape_poly
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
|
@ -22,8 +22,6 @@ from collections.abc import Sequence
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterator
|
||||
from functools import partial, reduce, total_ordering, wraps
|
||||
from functools import partial, reduce, total_ordering
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import jax
|
||||
@ -25,7 +25,6 @@ from jax import tree_util
|
||||
from jax.errors import KeyReuseError
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
|
@ -14,5 +14,5 @@
|
||||
# ==============================================================================
|
||||
"""Contains bindings for Mosaic."""
|
||||
|
||||
from jax._src.tpu_custom_call import as_tpu_kernel
|
||||
from jax._src.tpu_custom_call import lower_module_to_custom_call
|
||||
from jax._src.tpu_custom_call import as_tpu_kernel as as_tpu_kernel
|
||||
from jax._src.tpu_custom_call import lower_module_to_custom_call as lower_module_to_custom_call
|
||||
|
@ -14,4 +14,4 @@
|
||||
# ==============================================================================
|
||||
"""Contains bindings for Mosaic MLIR dialects."""
|
||||
|
||||
from jax._src.lib import tpu
|
||||
from jax._src.lib import tpu as tpu
|
||||
|
@ -13,55 +13,55 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from jax import ShapeDtypeStruct
|
||||
from jax import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from .core import (
|
||||
Barrier,
|
||||
ClusterBarrier,
|
||||
LaunchContext,
|
||||
MemRefTransform,
|
||||
TMABarrier,
|
||||
TileTransform,
|
||||
TransposeTransform,
|
||||
Union,
|
||||
as_gpu_kernel,
|
||||
Barrier as Barrier,
|
||||
ClusterBarrier as ClusterBarrier,
|
||||
LaunchContext as LaunchContext,
|
||||
MemRefTransform as MemRefTransform,
|
||||
TMABarrier as TMABarrier,
|
||||
TileTransform as TileTransform,
|
||||
TransposeTransform as TransposeTransform,
|
||||
Union as Union,
|
||||
as_gpu_kernel as as_gpu_kernel,
|
||||
)
|
||||
from .fragmented_array import (
|
||||
FragmentedArray,
|
||||
FragmentedLayout,
|
||||
WGMMA_LAYOUT,
|
||||
WGMMA_ROW_LAYOUT,
|
||||
WGSplatFragLayout,
|
||||
WGStridedFragLayout,
|
||||
FragmentedArray as FragmentedArray,
|
||||
FragmentedLayout as FragmentedLayout,
|
||||
WGMMA_LAYOUT as WGMMA_LAYOUT,
|
||||
WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
|
||||
WGSplatFragLayout as WGSplatFragLayout,
|
||||
WGStridedFragLayout as WGStridedFragLayout,
|
||||
)
|
||||
from .utils import (
|
||||
BarrierRef,
|
||||
CollectiveBarrierRef,
|
||||
DynamicSlice,
|
||||
Partition,
|
||||
Partition1D,
|
||||
bytewidth,
|
||||
c,
|
||||
commit_shared,
|
||||
debug_print,
|
||||
ds,
|
||||
fori,
|
||||
memref_fold,
|
||||
memref_slice,
|
||||
memref_reshape,
|
||||
memref_transpose,
|
||||
memref_unfold,
|
||||
memref_unsqueeze,
|
||||
single_thread,
|
||||
single_thread_predicate,
|
||||
thread_idx,
|
||||
tile_shape,
|
||||
warp_idx,
|
||||
warpgroup_barrier,
|
||||
warpgroup_idx,
|
||||
when,
|
||||
BarrierRef as BarrierRef,
|
||||
CollectiveBarrierRef as CollectiveBarrierRef,
|
||||
DynamicSlice as DynamicSlice,
|
||||
Partition as Partition,
|
||||
Partition1D as Partition1D,
|
||||
bytewidth as bytewidth,
|
||||
c as c,
|
||||
commit_shared as commit_shared,
|
||||
debug_print as debug_print,
|
||||
ds as ds,
|
||||
fori as fori,
|
||||
memref_fold as memref_fold,
|
||||
memref_slice as memref_slice,
|
||||
memref_reshape as memref_reshape,
|
||||
memref_transpose as memref_transpose,
|
||||
memref_unfold as memref_unfold,
|
||||
memref_unsqueeze as memref_unsqueeze,
|
||||
single_thread as single_thread,
|
||||
single_thread_predicate as single_thread_predicate,
|
||||
thread_idx as thread_idx,
|
||||
tile_shape as tile_shape,
|
||||
warp_idx as warp_idx,
|
||||
warpgroup_barrier as warpgroup_barrier,
|
||||
warpgroup_idx as warpgroup_idx,
|
||||
when as when,
|
||||
)
|
||||
from .wgmma import (
|
||||
WGMMAAccumulator,
|
||||
WGMMALayout,
|
||||
wgmma,
|
||||
WGMMAAccumulator as WGMMAAccumulator,
|
||||
WGMMALayout as WGMMALayout,
|
||||
wgmma as wgmma,
|
||||
)
|
||||
|
@ -17,21 +17,17 @@ import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
import itertools
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from absl import app
|
||||
import jax
|
||||
from jax import random
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.mosaic.gpu import * # noqa: F403
|
||||
import jax.numpy as jnp
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import nvgpu
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import scf
|
||||
import numpy as np
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
|
@ -20,7 +20,7 @@ import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
@ -16,13 +16,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial, lru_cache
|
||||
from typing import Optional
|
||||
import zlib
|
||||
|
||||
from typing import Any
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
|
@ -18,47 +18,47 @@ See the Pallas documentation at
|
||||
https://jax.readthedocs.io/en/latest/pallas.html.
|
||||
"""
|
||||
|
||||
from jax._src.pallas.core import Blocked
|
||||
from jax._src.pallas.core import BlockSpec
|
||||
from jax._src.pallas.core import CompilerParams
|
||||
from jax._src.pallas.core import core_map
|
||||
from jax._src.pallas.core import CostEstimate
|
||||
from jax._src.pallas.core import GridSpec
|
||||
from jax._src.pallas.core import IndexingMode
|
||||
from jax._src.pallas.core import MemorySpace
|
||||
from jax._src.pallas.core import MemoryRef
|
||||
from jax._src.pallas.core import no_block_spec
|
||||
from jax._src.pallas.core import Unblocked
|
||||
from jax._src.pallas.core import unblocked
|
||||
from jax._src.pallas.pallas_call import pallas_call
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax._src.pallas.primitives import atomic_add
|
||||
from jax._src.pallas.primitives import atomic_and
|
||||
from jax._src.pallas.primitives import atomic_cas
|
||||
from jax._src.pallas.primitives import atomic_max
|
||||
from jax._src.pallas.primitives import atomic_min
|
||||
from jax._src.pallas.primitives import atomic_or
|
||||
from jax._src.pallas.primitives import atomic_xchg
|
||||
from jax._src.pallas.primitives import atomic_xor
|
||||
from jax._src.pallas.primitives import debug_print
|
||||
from jax._src.pallas.primitives import dot
|
||||
from jax._src.pallas.primitives import load
|
||||
from jax._src.pallas.primitives import max_contiguous
|
||||
from jax._src.pallas.primitives import multiple_of
|
||||
from jax._src.pallas.primitives import num_programs
|
||||
from jax._src.pallas.primitives import program_id
|
||||
from jax._src.pallas.primitives import run_scoped
|
||||
from jax._src.pallas.primitives import store
|
||||
from jax._src.pallas.primitives import swap
|
||||
from jax._src.pallas.utils import cdiv
|
||||
from jax._src.pallas.utils import next_power_of_2
|
||||
from jax._src.pallas.utils import strides_from_shape
|
||||
from jax._src.pallas.utils import when
|
||||
from jax._src.state.discharge import run_state
|
||||
from jax._src.state.indexing import ds
|
||||
from jax._src.state.indexing import dslice
|
||||
from jax._src.state.indexing import Slice
|
||||
from jax._src.state.primitives import broadcast_to
|
||||
from jax._src.pallas.core import Blocked as Blocked
|
||||
from jax._src.pallas.core import BlockSpec as BlockSpec
|
||||
from jax._src.pallas.core import CompilerParams as CompilerParams
|
||||
from jax._src.pallas.core import core_map as core_map
|
||||
from jax._src.pallas.core import CostEstimate as CostEstimate
|
||||
from jax._src.pallas.core import GridSpec as GridSpec
|
||||
from jax._src.pallas.core import IndexingMode as IndexingMode
|
||||
from jax._src.pallas.core import MemorySpace as MemorySpace
|
||||
from jax._src.pallas.core import MemoryRef as MemoryRef
|
||||
from jax._src.pallas.core import no_block_spec as no_block_spec
|
||||
from jax._src.pallas.core import Unblocked as Unblocked
|
||||
from jax._src.pallas.core import unblocked as unblocked
|
||||
from jax._src.pallas.pallas_call import pallas_call as pallas_call
|
||||
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
|
||||
from jax._src.pallas.primitives import atomic_add as atomic_add
|
||||
from jax._src.pallas.primitives import atomic_and as atomic_and
|
||||
from jax._src.pallas.primitives import atomic_cas as atomic_cas
|
||||
from jax._src.pallas.primitives import atomic_max as atomic_max
|
||||
from jax._src.pallas.primitives import atomic_min as atomic_min
|
||||
from jax._src.pallas.primitives import atomic_or as atomic_or
|
||||
from jax._src.pallas.primitives import atomic_xchg as atomic_xchg
|
||||
from jax._src.pallas.primitives import atomic_xor as atomic_xor
|
||||
from jax._src.pallas.primitives import debug_print as debug_print
|
||||
from jax._src.pallas.primitives import dot as dot
|
||||
from jax._src.pallas.primitives import load as load
|
||||
from jax._src.pallas.primitives import max_contiguous as max_contiguous
|
||||
from jax._src.pallas.primitives import multiple_of as multiple_of
|
||||
from jax._src.pallas.primitives import num_programs as num_programs
|
||||
from jax._src.pallas.primitives import program_id as program_id
|
||||
from jax._src.pallas.primitives import run_scoped as run_scoped
|
||||
from jax._src.pallas.primitives import store as store
|
||||
from jax._src.pallas.primitives import swap as swap
|
||||
from jax._src.pallas.utils import cdiv as cdiv
|
||||
from jax._src.pallas.utils import next_power_of_2 as next_power_of_2
|
||||
from jax._src.pallas.utils import strides_from_shape as strides_from_shape
|
||||
from jax._src.pallas.utils import when as when
|
||||
from jax._src.state.discharge import run_state as run_state
|
||||
from jax._src.state.indexing import ds as ds
|
||||
from jax._src.state.indexing import dslice as dslice
|
||||
from jax._src.state.indexing import Slice as Slice
|
||||
from jax._src.state.primitives import broadcast_to as broadcast_to
|
||||
|
||||
|
||||
ANY = MemorySpace.ANY
|
||||
|
@ -17,23 +17,23 @@
|
||||
These APIs are highly unstable and can change weekly. Use at your own risk.
|
||||
"""
|
||||
|
||||
from jax._src.pallas.mosaic_gpu.core import Barrier
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive
|
||||
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
||||
from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
|
||||
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
|
||||
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
|
||||
GMEM = GPUMemorySpace.GMEM
|
||||
|
@ -18,8 +18,6 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax.experimental.pallas.ops.tpu.megablox.ops import gmm
|
||||
from jax.experimental.pallas.ops.tpu.megablox.ops import gmm as gmm
|
||||
|
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
|
||||
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as paged_attention
|
||||
|
@ -12,21 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes as BlockSizes
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference as make_masked_mha_reference
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference as make_masked_mqa_reference
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha as make_splash_mha
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device as make_splash_mha_single_device
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa as make_splash_mqa
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device as make_splash_mqa_single_device
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout as QKVLayout
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds as SegmentIds
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask as CausalMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask as FullMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask as LocalMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask as make_causal_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask as make_local_attention_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask as make_random_mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask as Mask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask as MultiHeadMask
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask as NumpyMask
|
||||
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import collections
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import Dict, List, NamedTuple, Set, Tuple
|
||||
from typing import NamedTuple
|
||||
from jax import util as jax_util
|
||||
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
|
||||
import numpy as np
|
||||
|
@ -14,47 +14,47 @@
|
||||
|
||||
"""Mosaic-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.mosaic import core
|
||||
from jax._src.pallas.mosaic.core import create_tensorcore_mesh
|
||||
from jax._src.pallas.mosaic.core import dma_semaphore
|
||||
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic.core import semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.core import TPUCompilerParams
|
||||
from jax._src.pallas.mosaic.core import runtime_assert_enabled
|
||||
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import ARBITRARY
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import PARALLEL
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import bitcast
|
||||
from jax._src.pallas.mosaic.primitives import delay
|
||||
from jax._src.pallas.mosaic.primitives import device_id
|
||||
from jax._src.pallas.mosaic.primitives import DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import repeat
|
||||
from jax._src.pallas.mosaic.primitives import roll
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_read
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_wait
|
||||
from jax._src.pallas.mosaic.random import to_pallas_key
|
||||
from jax._src.pallas.mosaic import core as core
|
||||
from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh
|
||||
from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore
|
||||
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic.core import semaphore as semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
|
||||
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
|
||||
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations as make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import PARALLEL as PARALLEL
|
||||
from jax._src.pallas.mosaic.primitives import async_copy as async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import bitcast as bitcast
|
||||
from jax._src.pallas.mosaic.primitives import delay as delay
|
||||
from jax._src.pallas.mosaic.primitives import device_id as device_id
|
||||
from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import repeat as repeat
|
||||
from jax._src.pallas.mosaic.primitives import roll as roll
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait
|
||||
from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key
|
||||
# Remove this import after October 22th 2024.
|
||||
from jax._src.tpu_custom_call import CostEstimate
|
||||
from jax._src.tpu_custom_call import CostEstimate as CostEstimate
|
||||
|
||||
# TODO(cperivol): Temporary alias to the global run_scoped. Remove
|
||||
# this once everyone has migrated to the pallas core one.
|
||||
from jax._src.pallas.primitives import run_scoped
|
||||
from jax._src.pallas.primitives import run_scoped as run_scoped
|
||||
|
||||
import types
|
||||
from jax._src.pallas.mosaic.verification import assume
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
"""Triton-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.triton.core import TritonCompilerParams
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import debug_barrier
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
||||
from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams
|
||||
from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh
|
||||
from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm
|
||||
|
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import io
|
||||
from typing import Optional, Union
|
||||
|
||||
import jax
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
@ -52,7 +52,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, convolution, fft, linalg,
|
||||
special, control_flow, ann)
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo, sdy
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
|
||||
as_hashable_function, memoize, partition_list,
|
||||
merge_lists, split_list, subs_list2)
|
||||
|
@ -34,7 +34,6 @@ from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
|
@ -49,9 +49,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax.lax import (
|
||||
_const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers)
|
||||
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.setops import _unique
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
@ -31,8 +31,7 @@ from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse import bcoo
|
||||
from jax.experimental.sparse.util import (
|
||||
nfold_vmap, _count_stored_elements,
|
||||
_csr_to_coo, _dot_general_validated_shape,
|
||||
CuSparseEfficiencyWarning, SparseInfo, Shape)
|
||||
_csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape)
|
||||
from jax.util import split_list, safe_zip
|
||||
|
||||
from jax._src import api_util
|
||||
|
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -35,7 +34,7 @@ from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import _const
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.numpy.util import promote_dtypes
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.typing import Array, DTypeLike
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
|
@ -26,7 +26,6 @@ from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.typing import DTypeLike
|
||||
from jax.experimental import sparse
|
||||
import jax.numpy as jnp
|
||||
|
@ -15,7 +15,7 @@
|
||||
"""Sparse utilities."""
|
||||
|
||||
import functools
|
||||
from typing import Any, NamedTuple, Union
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
@ -23,8 +23,6 @@ from jax import lax
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import stages
|
||||
from jax._src.api_util import flatten_axes
|
||||
import jax.numpy as jnp
|
||||
from jax.util import safe_zip
|
||||
|
@ -14,9 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import jax
|
||||
from jax.experimental import mesh_utils
|
||||
|
@ -27,7 +27,7 @@ from jax._src.interpreters.mlir import (
|
||||
Token as Token,
|
||||
TokenSet as TokenSet,
|
||||
Value as Value,
|
||||
call_lowering as _call_lowering,
|
||||
call_lowering as _call_lowering, # noqa: F401
|
||||
_lowerings as _lowerings,
|
||||
_platform_specific_lowerings as _platform_specific_lowerings,
|
||||
aval_to_ir_type as aval_to_ir_type,
|
||||
@ -41,7 +41,7 @@ from jax._src.interpreters.mlir import (
|
||||
dtype_to_ir_type as dtype_to_ir_type,
|
||||
emit_python_callback as emit_python_callback,
|
||||
flatten_ir_types as flatten_ir_types,
|
||||
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me
|
||||
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401
|
||||
flatten_ir_values as flatten_ir_values,
|
||||
unflatten_ir_values_like_types as unflatten_ir_values_like_types,
|
||||
func_dialect as func_dialect,
|
||||
|
@ -38,9 +38,9 @@ from jax._src.op_shardings import (
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping as ArrayMapping,
|
||||
UNSPECIFIED as _UNSPECIFIED,
|
||||
UNSPECIFIED as _UNSPECIFIED, # noqa: F401
|
||||
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
|
||||
is_unspecified as _is_unspecified,
|
||||
is_unspecified as _is_unspecified, # noqa: F401
|
||||
)
|
||||
|
||||
from jax._src.sharding_specs import (
|
||||
|
@ -13,31 +13,31 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.lax.linalg import (
|
||||
cholesky,
|
||||
cholesky_p,
|
||||
eig,
|
||||
eig_p,
|
||||
eigh,
|
||||
eigh_p,
|
||||
hessenberg,
|
||||
hessenberg_p,
|
||||
lu,
|
||||
lu_p,
|
||||
lu_pivots_to_permutation,
|
||||
householder_product,
|
||||
householder_product_p,
|
||||
qr,
|
||||
qr_p,
|
||||
svd,
|
||||
svd_p,
|
||||
triangular_solve,
|
||||
triangular_solve_p,
|
||||
tridiagonal,
|
||||
tridiagonal_p,
|
||||
tridiagonal_solve,
|
||||
tridiagonal_solve_p,
|
||||
schur,
|
||||
schur_p
|
||||
cholesky as cholesky,
|
||||
cholesky_p as cholesky_p,
|
||||
eig as eig,
|
||||
eig_p as eig_p,
|
||||
eigh as eigh,
|
||||
eigh_p as eigh_p,
|
||||
hessenberg as hessenberg,
|
||||
hessenberg_p as hessenberg_p,
|
||||
lu as lu,
|
||||
lu_p as lu_p,
|
||||
lu_pivots_to_permutation as lu_pivots_to_permutation,
|
||||
householder_product as householder_product,
|
||||
householder_product_p as householder_product_p,
|
||||
qr as qr,
|
||||
qr_p as qr_p,
|
||||
svd as svd,
|
||||
svd_p as svd_p,
|
||||
triangular_solve as triangular_solve,
|
||||
triangular_solve_p as triangular_solve_p,
|
||||
tridiagonal as tridiagonal,
|
||||
tridiagonal_p as tridiagonal_p,
|
||||
tridiagonal_solve as tridiagonal_solve,
|
||||
tridiagonal_solve_p as tridiagonal_solve_p,
|
||||
schur as schur,
|
||||
schur_p as schur_p,
|
||||
)
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
from jax.numpy import fft as fft
|
||||
from jax.numpy import linalg as linalg
|
||||
|
||||
from jax._src.basearray import Array as ndarray
|
||||
from jax._src.basearray import Array as ndarray # noqa: F401
|
||||
|
||||
from jax._src.dtypes import (
|
||||
isdtype as isdtype,
|
||||
@ -53,7 +53,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
bincount as bincount,
|
||||
blackman as blackman,
|
||||
block as block,
|
||||
bool_ as bool, # Array API alias for bool_
|
||||
bool_ as bool, # Array API alias for bool_ # noqa: F401
|
||||
bool_ as bool_,
|
||||
broadcast_arrays as broadcast_arrays,
|
||||
broadcast_shapes as broadcast_shapes,
|
||||
|
@ -13,6 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.stats.nbinom import (
|
||||
logpmf,
|
||||
pmf,
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
)
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
from jax._src.sharding import Sharding as Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
XLACompatibleSharding as _deprecated_XLACompatibleSharding,
|
||||
NamedSharding as NamedSharding,
|
||||
SingleDeviceSharding as SingleDeviceSharding,
|
||||
PmapSharding as PmapSharding,
|
||||
@ -28,7 +27,7 @@ from jax._src.partition_spec import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
from jax._src.interpreters.pxla import Mesh as Mesh
|
||||
from jax._src.mesh import AbstractMesh
|
||||
from jax._src.mesh import AbstractMesh as AbstractMesh
|
||||
|
||||
_deprecations = {
|
||||
# Finalized 2024-10-01; remove after 2025-01-01.
|
||||
|
@ -17,8 +17,6 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from jax._src.lib import xla_client
|
||||
import jax._src.xla_bridge as xb
|
||||
|
@ -17,7 +17,6 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
|
||||
from jax._src.lib import xla_client
|
||||
import jax._src.xla_bridge as xb
|
||||
|
@ -133,42 +133,3 @@ max-complexity = 18
|
||||
"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
|
||||
"docs/jep/9407-type-promotion.ipynb" = ["F811"]
|
||||
"docs/autodidax.ipynb" = ["F811"]
|
||||
# Note: we don't use jax/*.py because this matches contents of jax/_src
|
||||
"__init__.py" = ["F401"]
|
||||
"jax/abstract_arrays.py" = ["F401"]
|
||||
"jax/ad_checkpoint.py" = ["F401"]
|
||||
"jax/api_util.py" = ["F401"]
|
||||
"jax/cloud_tpu_init.py" = ["F401"]
|
||||
"jax/core.py" = ["F401"]
|
||||
"jax/custom_batching.py" = ["F401"]
|
||||
"jax/custom_derivatives.py" = ["F401"]
|
||||
"jax/custom_transpose.py" = ["F401"]
|
||||
"jax/debug.py" = ["F401"]
|
||||
"jax/distributed.py" = ["F401"]
|
||||
"jax/dlpack.py" = ["F401"]
|
||||
"jax/dtypes.py" = ["F401"]
|
||||
"jax/errors.py" = ["F401"]
|
||||
"jax/experimental/*.py" = ["F401"]
|
||||
"jax/extend/*.py" = ["F401"]
|
||||
"jax/flatten_util.py" = ["F401"]
|
||||
"jax/interpreters/ad.py" = ["F401"]
|
||||
"jax/interpreters/batching.py" = ["F401"]
|
||||
"jax/interpreters/mlir.py" = ["F401"]
|
||||
"jax/interpreters/partial_eval.py" = ["F401"]
|
||||
"jax/interpreters/pxla.py" = ["F401"]
|
||||
"jax/interpreters/xla.py" = ["F401"]
|
||||
"jax/lax/*.py" = ["F401"]
|
||||
"jax/linear_util.py" = ["F401"]
|
||||
"jax/monitoring.py" = ["F401"]
|
||||
"jax/nn/*.py" = ["F401"]
|
||||
"jax/numpy/*.py" = ["F401"]
|
||||
"jax/prng.py" = ["F401"]
|
||||
"jax/profiler.py" = ["F401"]
|
||||
"jax/random.py" = ["F401"]
|
||||
"jax/scipy/*.py" = ["F401"]
|
||||
"jax/sharding.py" = ["F401"]
|
||||
"jax/stages.py" = ["F401"]
|
||||
"jax/test_util.py" = ["F401"]
|
||||
"jax/tree_util.py" = ["F401"]
|
||||
"jax/typing.py" = ["F401"]
|
||||
"jax/util.py" = ["F401"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user