Cleanup: fix unused imports & mark exported names

This commit is contained in:
Jake VanderPlas 2024-10-16 12:19:58 -07:00
parent 6a00055980
commit de3191fab3
56 changed files with 314 additions and 364 deletions

View File

@ -81,7 +81,7 @@ del _xc
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
@ -122,7 +122,7 @@ from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
from jax._src.xla_bridge import process_indices as process_indices
from jax._src.callback import pure_callback as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.ad_checkpoint import checkpoint_wrapper as remat # noqa: F401
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
@ -249,6 +249,6 @@ else:
del _deprecation_getattr
del _typing
import jax.lib # TODO(phawkins): remove this export.
import jax.lib # TODO(phawkins): remove this export. # noqa: F401
# trailer

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .cluster import ClusterEnv
from .cluster import ClusterEnv as ClusterEnv
# Order of declaration of the cluster environments
# will dictate the order in which they will be checked.
@ -20,9 +20,9 @@ from .cluster import ClusterEnv
# the user did not explicitly provide the arguments
# to :func:`jax.distributed.initialize`, the first
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster
from .k8s_cluster import K8sCluster
from .ompi_cluster import OmpiCluster as OmpiCluster
from .slurm_cluster import SlurmCluster as SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
from .k8s_cluster import K8sCluster as K8sCluster

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .fusion import cudnn_fusion
from .fusion import cudnn_fusion as cudnn_fusion

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.debugger.core import breakpoint
from jax._src.debugger.core import breakpoint as breakpoint
from jax._src.debugger import cli_debugger
from jax._src.debugger import colab_debugger
from jax._src.debugger import web_debugger

View File

@ -12,25 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the control flow primitives."""
from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
cummin, cummin_p, cumlogsumexp,
cumlogsumexp_p, cumprod,
cumprod_p, cumsum, cumsum_p,
cumred_reduce_window_impl,
fori_loop, map,
scan, scan_bind, scan_p,
_scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch,
platform_dependent)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
_custom_linear_solve_impl,
linear_solve_p)
from jax._src.lax.control_flow.loops import (
associative_scan as associative_scan,
cummax as cummax,
cummax_p as cummax_p,
cummin as cummin,
cummin_p as cummin_p,
cumlogsumexp as cumlogsumexp,
cumlogsumexp_p as cumlogsumexp_p,
cumprod as cumprod,
cumprod_p as cumprod_p,
cumsum as cumsum,
cumsum_p as cumsum_p,
cumred_reduce_window_impl as cumred_reduce_window_impl,
fori_loop as fori_loop,
map as map,
scan as scan,
scan_bind as scan_bind,
scan_p as scan_p,
_scan_impl as _scan_impl,
while_loop as while_loop,
while_p as while_p,
)
from jax._src.lax.control_flow.conditionals import (
cond as cond,
cond_p as cond_p,
switch as switch,
platform_dependent as platform_dependent,
)
from jax._src.lax.control_flow.solves import (
custom_linear_solve as custom_linear_solve,
custom_root as custom_root,
_custom_linear_solve_impl as _custom_linear_solve_impl,
linear_solve_p as linear_solve_p,
)
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
_initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)
from jax._src.lax.control_flow.common import (
_initial_style_open_jaxpr as _initial_style_open_jaxpr,
_initial_style_jaxpr as _initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
_check_tree_and_avals as _check_tree_and_avals,
)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.lax.lax import optimization_barrier_p
from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p

View File

@ -21,7 +21,6 @@ import gc
import os
import pathlib
import re
from typing import Any
try:
import jaxlib as jaxlib
@ -84,9 +83,9 @@ version = check_jaxlib_version(
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()
import jaxlib.utils as utils
import jaxlib.utils as utils # noqa: F401
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack
import jaxlib.lapack as lapack # noqa: F401
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
@ -99,18 +98,18 @@ def _xla_gc_callback(*args):
gc.callbacks.append(_xla_gc_callback)
try:
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error
import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error # noqa: F401
except ImportError:
try:
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error
import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error # noqa: F401
except ImportError:
cuda_versions = None
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401
import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
@ -118,10 +117,10 @@ import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error
# branch on the Jax github.
xla_extension_version: int = getattr(xla_client, '_version', 0)
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error
import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401
# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version

View File

@ -13,14 +13,14 @@
# limitations under the License.
"""Module for state."""
from jax._src.state.types import (
AbstractRef,
AccumEffect,
ReadEffect,
RefEffect,
StateEffect,
Transform,
TransformedRef,
WriteEffect,
get_ref_state_effects,
shaped_array_ref,
AbstractRef as AbstractRef,
AccumEffect as AccumEffect,
ReadEffect as ReadEffect,
RefEffect as RefEffect,
StateEffect as StateEffect,
Transform as Transform,
TransformedRef as TransformedRef,
WriteEffect as WriteEffect,
get_ref_state_effects as get_ref_state_effects,
shaped_array_ref as shaped_array_ref,
)

View File

@ -13,14 +13,14 @@
# limitations under the License.
from jax._src.ad_checkpoint import (
checkpoint,
checkpoint_policies,
checkpoint_name,
print_saved_residuals,
remat,
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
checkpoint_name as checkpoint_name,
print_saved_residuals as print_saved_residuals,
remat as remat,
)
from jax._src.interpreters.partial_eval import (
Recompute,
Saveable,
Offloadable,
Recompute as Recompute,
Saveable as Saveable,
Offloadable as Offloadable,
)

View File

@ -14,12 +14,12 @@
from jax._src.api_util import (
argnums_partial,
donation_vector,
flatten_axes,
flatten_fun,
flatten_fun_nokwargs,
rebase_donate_argnums,
safe_map,
shaped_abstractify,
argnums_partial as argnums_partial,
donation_vector as donation_vector,
flatten_axes as flatten_axes,
flatten_fun as flatten_fun,
flatten_fun_nokwargs as flatten_fun_nokwargs,
rebase_donate_argnums as rebase_donate_argnums,
safe_map as safe_map,
shaped_abstractify as shaped_abstractify,
)

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.cloud_tpu_init import cloud_tpu_init
from jax._src.cloud_tpu_init import cloud_tpu_init as cloud_tpu_init

View File

@ -42,7 +42,7 @@ from jax._src.core import (
Literal as Literal,
MainTrace as MainTrace,
MapPrimitive as MapPrimitive,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
OpaqueTraceState as OpaqueTraceState,
NameGatheringSubst as NameGatheringSubst,
OutDBIdx as OutDBIdx,
@ -58,9 +58,9 @@ from jax._src.core import (
TraceStack as TraceStack,
TraceState as TraceState,
Tracer as Tracer,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE,
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE,
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
UnshapedArray as UnshapedArray,
Value as Value,
Var as Var,

View File

@ -16,9 +16,9 @@
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.custom_derivatives import (
_initial_style_jaxpr,
_sum_tangents,
_zeros_like_pytree,
_initial_style_jaxpr, # noqa: F401
_sum_tangents, # noqa: F401
_zeros_like_pytree, # noqa: F401
closure_convert as closure_convert,
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,

View File

@ -18,10 +18,10 @@
from jax._src.dtypes import (
bfloat16 as bfloat16,
canonicalize_dtype as canonicalize_dtype,
finfo, # TODO(phawkins): switch callers to jnp.finfo?
finfo, # TODO(phawkins): switch callers to jnp.finfo? # noqa: F401
float0 as float0,
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
iinfo, # TODO(phawkins): switch callers to jnp.iinfo? # noqa: F401
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? # noqa: F401
extended as extended,
prng_key as prng_key,
result_type as result_type,

View File

@ -29,11 +29,9 @@ from typing import Any, Optional
import jax
from jax._src import array
from jax._src import config
from jax._src import distributed
from jax._src import sharding
from jax._src import sharding_impls
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src.layout import Layout
from jax._src import typing
from jax._src import util
from jax._src.lib import xla_extension as xe

View File

@ -27,7 +27,6 @@ import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src import array
from jax._src import xla_bridge as xb
from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization

View File

@ -30,7 +30,6 @@ from typing import Any
import warnings
from absl import flags
import flax
from flax import linen as nn
import jax

View File

@ -39,8 +39,8 @@ from absl import logging
import numpy.random as npr
import jax # Must import before TF
from jax.experimental import jax2tf # Defines needed flags
from jax._src import test_util # Defines needed flags
from jax.experimental import jax2tf # Defines needed flags # noqa: F401
from jax._src import test_util # Defines needed flags # noqa: F401
jax.config.parse_flags_with_absl()

View File

@ -28,7 +28,6 @@ import unittest
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu

View File

@ -67,7 +67,6 @@ from jax._src import config
from jax._src import test_util as jtu
from jax.experimental import jax2tf
from jax.interpreters import mlir
from jax._src.interpreters import xla
import numpy as np
import tensorflow as tf

View File

@ -45,7 +45,6 @@ from jax._src import util
from jax._src.export import shape_poly
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -22,8 +22,6 @@ from collections.abc import Sequence
import contextlib
from functools import partial
import logging
import math
import os
import re
from typing import Any
import unittest

View File

@ -16,7 +16,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Iterator
from functools import partial, reduce, total_ordering, wraps
from functools import partial, reduce, total_ordering
from typing import Any, NamedTuple
import jax
@ -25,7 +25,6 @@ from jax import tree_util
from jax.errors import KeyReuseError
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit

View File

@ -14,5 +14,5 @@
# ==============================================================================
"""Contains bindings for Mosaic."""
from jax._src.tpu_custom_call import as_tpu_kernel
from jax._src.tpu_custom_call import lower_module_to_custom_call
from jax._src.tpu_custom_call import as_tpu_kernel as as_tpu_kernel
from jax._src.tpu_custom_call import lower_module_to_custom_call as lower_module_to_custom_call

View File

@ -14,4 +14,4 @@
# ==============================================================================
"""Contains bindings for Mosaic MLIR dialects."""
from jax._src.lib import tpu
from jax._src.lib import tpu as tpu

View File

@ -13,55 +13,55 @@
# limitations under the License.
# ==============================================================================
from jax import ShapeDtypeStruct
from jax import ShapeDtypeStruct as ShapeDtypeStruct
from .core import (
Barrier,
ClusterBarrier,
LaunchContext,
MemRefTransform,
TMABarrier,
TileTransform,
TransposeTransform,
Union,
as_gpu_kernel,
Barrier as Barrier,
ClusterBarrier as ClusterBarrier,
LaunchContext as LaunchContext,
MemRefTransform as MemRefTransform,
TMABarrier as TMABarrier,
TileTransform as TileTransform,
TransposeTransform as TransposeTransform,
Union as Union,
as_gpu_kernel as as_gpu_kernel,
)
from .fragmented_array import (
FragmentedArray,
FragmentedLayout,
WGMMA_LAYOUT,
WGMMA_ROW_LAYOUT,
WGSplatFragLayout,
WGStridedFragLayout,
FragmentedArray as FragmentedArray,
FragmentedLayout as FragmentedLayout,
WGMMA_LAYOUT as WGMMA_LAYOUT,
WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
WGSplatFragLayout as WGSplatFragLayout,
WGStridedFragLayout as WGStridedFragLayout,
)
from .utils import (
BarrierRef,
CollectiveBarrierRef,
DynamicSlice,
Partition,
Partition1D,
bytewidth,
c,
commit_shared,
debug_print,
ds,
fori,
memref_fold,
memref_slice,
memref_reshape,
memref_transpose,
memref_unfold,
memref_unsqueeze,
single_thread,
single_thread_predicate,
thread_idx,
tile_shape,
warp_idx,
warpgroup_barrier,
warpgroup_idx,
when,
BarrierRef as BarrierRef,
CollectiveBarrierRef as CollectiveBarrierRef,
DynamicSlice as DynamicSlice,
Partition as Partition,
Partition1D as Partition1D,
bytewidth as bytewidth,
c as c,
commit_shared as commit_shared,
debug_print as debug_print,
ds as ds,
fori as fori,
memref_fold as memref_fold,
memref_slice as memref_slice,
memref_reshape as memref_reshape,
memref_transpose as memref_transpose,
memref_unfold as memref_unfold,
memref_unsqueeze as memref_unsqueeze,
single_thread as single_thread,
single_thread_predicate as single_thread_predicate,
thread_idx as thread_idx,
tile_shape as tile_shape,
warp_idx as warp_idx,
warpgroup_barrier as warpgroup_barrier,
warpgroup_idx as warpgroup_idx,
when as when,
)
from .wgmma import (
WGMMAAccumulator,
WGMMALayout,
wgmma,
WGMMAAccumulator as WGMMAAccumulator,
WGMMALayout as WGMMALayout,
wgmma as wgmma,
)

View File

@ -17,21 +17,17 @@ import contextlib
import dataclasses
import enum
import itertools
import os
import warnings
from absl import app
import jax
from jax import random
from jax._src.interpreters import mlir
from jax._src import test_util as jtu
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import * # noqa: F403
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import nvgpu
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import scf
import numpy as np

View File

@ -16,7 +16,6 @@
import dataclasses
import itertools
import math
from typing import Any
import jax

View File

@ -20,7 +20,7 @@ import dataclasses
import enum
import functools
import math
from typing import Any, Literal, cast
from typing import Any, Literal
import jax
from jax import numpy as jnp

View File

@ -16,13 +16,12 @@
from __future__ import annotations
from functools import partial, lru_cache
from typing import Optional
import zlib
from typing import Any
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import core
from jax._src.interpreters import ad
from jax._src.interpreters import batching

View File

@ -18,47 +18,47 @@ See the Pallas documentation at
https://jax.readthedocs.io/en/latest/pallas.html.
"""
from jax._src.pallas.core import Blocked
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import CompilerParams
from jax._src.pallas.core import core_map
from jax._src.pallas.core import CostEstimate
from jax._src.pallas.core import GridSpec
from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import MemorySpace
from jax._src.pallas.core import MemoryRef
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
from jax._src.pallas.primitives import atomic_and
from jax._src.pallas.primitives import atomic_cas
from jax._src.pallas.primitives import atomic_max
from jax._src.pallas.primitives import atomic_min
from jax._src.pallas.primitives import atomic_or
from jax._src.pallas.primitives import atomic_xchg
from jax._src.pallas.primitives import atomic_xor
from jax._src.pallas.primitives import debug_print
from jax._src.pallas.primitives import dot
from jax._src.pallas.primitives import load
from jax._src.pallas.primitives import max_contiguous
from jax._src.pallas.primitives import multiple_of
from jax._src.pallas.primitives import num_programs
from jax._src.pallas.primitives import program_id
from jax._src.pallas.primitives import run_scoped
from jax._src.pallas.primitives import store
from jax._src.pallas.primitives import swap
from jax._src.pallas.utils import cdiv
from jax._src.pallas.utils import next_power_of_2
from jax._src.pallas.utils import strides_from_shape
from jax._src.pallas.utils import when
from jax._src.state.discharge import run_state
from jax._src.state.indexing import ds
from jax._src.state.indexing import dslice
from jax._src.state.indexing import Slice
from jax._src.state.primitives import broadcast_to
from jax._src.pallas.core import Blocked as Blocked
from jax._src.pallas.core import BlockSpec as BlockSpec
from jax._src.pallas.core import CompilerParams as CompilerParams
from jax._src.pallas.core import core_map as core_map
from jax._src.pallas.core import CostEstimate as CostEstimate
from jax._src.pallas.core import GridSpec as GridSpec
from jax._src.pallas.core import IndexingMode as IndexingMode
from jax._src.pallas.core import MemorySpace as MemorySpace
from jax._src.pallas.core import MemoryRef as MemoryRef
from jax._src.pallas.core import no_block_spec as no_block_spec
from jax._src.pallas.core import Unblocked as Unblocked
from jax._src.pallas.core import unblocked as unblocked
from jax._src.pallas.pallas_call import pallas_call as pallas_call
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
from jax._src.pallas.primitives import atomic_add as atomic_add
from jax._src.pallas.primitives import atomic_and as atomic_and
from jax._src.pallas.primitives import atomic_cas as atomic_cas
from jax._src.pallas.primitives import atomic_max as atomic_max
from jax._src.pallas.primitives import atomic_min as atomic_min
from jax._src.pallas.primitives import atomic_or as atomic_or
from jax._src.pallas.primitives import atomic_xchg as atomic_xchg
from jax._src.pallas.primitives import atomic_xor as atomic_xor
from jax._src.pallas.primitives import debug_print as debug_print
from jax._src.pallas.primitives import dot as dot
from jax._src.pallas.primitives import load as load
from jax._src.pallas.primitives import max_contiguous as max_contiguous
from jax._src.pallas.primitives import multiple_of as multiple_of
from jax._src.pallas.primitives import num_programs as num_programs
from jax._src.pallas.primitives import program_id as program_id
from jax._src.pallas.primitives import run_scoped as run_scoped
from jax._src.pallas.primitives import store as store
from jax._src.pallas.primitives import swap as swap
from jax._src.pallas.utils import cdiv as cdiv
from jax._src.pallas.utils import next_power_of_2 as next_power_of_2
from jax._src.pallas.utils import strides_from_shape as strides_from_shape
from jax._src.pallas.utils import when as when
from jax._src.state.discharge import run_state as run_state
from jax._src.state.indexing import ds as ds
from jax._src.state.indexing import dslice as dslice
from jax._src.state.indexing import Slice as Slice
from jax._src.state.primitives import broadcast_to as broadcast_to
ANY = MemorySpace.ANY

View File

@ -17,23 +17,23 @@
These APIs are highly unstable and can change weekly. Use at your own risk.
"""
from jax._src.pallas.mosaic_gpu.core import Barrier
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
GMEM = GPUMemorySpace.GMEM

View File

@ -18,8 +18,6 @@ from __future__ import annotations
import functools
from typing import Optional
import jax
from jax import lax
import jax.numpy as jnp

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.megablox.ops import gmm
from jax.experimental.pallas.ops.tpu.megablox.ops import gmm as gmm

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as paged_attention

View File

@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes as BlockSizes
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference as make_masked_mha_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference as make_masked_mqa_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha as make_splash_mha
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device as make_splash_mha_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa as make_splash_mqa
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device as make_splash_mqa_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout as QKVLayout
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds as SegmentIds
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask as CausalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask as FullMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask as LocalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask as make_causal_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask as make_local_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask as make_random_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask as Mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask as MultiHeadMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask as NumpyMask

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import collections
from collections.abc import Callable
import functools
from typing import Dict, List, NamedTuple, Set, Tuple
from typing import NamedTuple
from jax import util as jax_util
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
import numpy as np

View File

@ -14,47 +14,47 @@
"""Mosaic-specific Pallas APIs."""
from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import create_tensorcore_mesh
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.core import TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.pipeline import PARALLEL
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
from jax._src.pallas.mosaic.primitives import delay
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import prng_random_bits
from jax._src.pallas.mosaic.primitives import prng_seed
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.random import to_pallas_key
from jax._src.pallas.mosaic import core as core
from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh
from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore as semaphore
from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace
from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations as make_pipeline_allocations
from jax._src.pallas.mosaic.pipeline import PARALLEL as PARALLEL
from jax._src.pallas.mosaic.primitives import async_copy as async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast as bitcast
from jax._src.pallas.mosaic.primitives import delay as delay
from jax._src.pallas.mosaic.primitives import device_id as device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy
from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits
from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed
from jax._src.pallas.mosaic.primitives import repeat as repeat
from jax._src.pallas.mosaic.primitives import roll as roll
from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait
from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key
# Remove this import after October 22th 2024.
from jax._src.tpu_custom_call import CostEstimate
from jax._src.tpu_custom_call import CostEstimate as CostEstimate
# TODO(cperivol): Temporary alias to the global run_scoped. Remove
# this once everyone has migrated to the pallas core one.
from jax._src.pallas.primitives import run_scoped
from jax._src.pallas.primitives import run_scoped as run_scoped
import types
from jax._src.pallas.mosaic.verification import assume

View File

@ -14,7 +14,7 @@
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton.core import TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm
from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import pickle
import io
from typing import Optional, Union
import jax
from jax._src.lib import xla_client as xc

View File

@ -52,7 +52,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
special, control_flow, ann)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists, split_list, subs_list2)

View File

@ -34,7 +34,6 @@ from __future__ import annotations
from functools import partial
import operator
from typing import Optional, Union
import jax
from jax import tree_util

View File

@ -49,9 +49,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers)
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
from jax._src.lib.mlir import ir
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.setops import _unique
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis

View File

@ -31,8 +31,7 @@ from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,
_csr_to_coo, _dot_general_validated_shape,
CuSparseEfficiencyWarning, SparseInfo, Shape)
_csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape)
from jax.util import split_list, safe_zip
from jax._src import api_util

View File

@ -17,7 +17,6 @@ from __future__ import annotations
from functools import partial
import operator
from typing import Optional
import warnings
import numpy as np
@ -35,7 +34,7 @@ from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse
from jax._src.numpy.util import promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.typing import Array, DTypeLike
import jax.numpy as jnp

View File

@ -26,7 +26,6 @@ from jax import lax
from jax import tree_util
from jax._src import test_util as jtu
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax._src.typing import DTypeLike
from jax.experimental import sparse
import jax.numpy as jnp

View File

@ -15,7 +15,7 @@
"""Sparse utilities."""
import functools
from typing import Any, NamedTuple, Union
from typing import NamedTuple
import numpy as np
import jax
@ -23,8 +23,6 @@ from jax import lax
from jax import tree_util
from jax import vmap
from jax._src import core
from jax._src import dtypes
from jax._src import stages
from jax._src.api_util import flatten_axes
import jax.numpy as jnp
from jax.util import safe_zip

View File

@ -14,9 +14,7 @@
from __future__ import annotations
import abc
from collections.abc import Sequence
from typing import Optional
import jax
from jax.experimental import mesh_utils

View File

@ -27,7 +27,7 @@ from jax._src.interpreters.mlir import (
Token as Token,
TokenSet as TokenSet,
Value as Value,
call_lowering as _call_lowering,
call_lowering as _call_lowering, # noqa: F401
_lowerings as _lowerings,
_platform_specific_lowerings as _platform_specific_lowerings,
aval_to_ir_type as aval_to_ir_type,
@ -41,7 +41,7 @@ from jax._src.interpreters.mlir import (
dtype_to_ir_type as dtype_to_ir_type,
emit_python_callback as emit_python_callback,
flatten_ir_types as flatten_ir_types,
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401
flatten_ir_values as flatten_ir_values,
unflatten_ir_values_like_types as unflatten_ir_values_like_types,
func_dialect as func_dialect,

View File

@ -38,9 +38,9 @@ from jax._src.op_shardings import (
from jax._src.sharding_impls import (
ArrayMapping as ArrayMapping,
UNSPECIFIED as _UNSPECIFIED,
UNSPECIFIED as _UNSPECIFIED, # noqa: F401
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
is_unspecified as _is_unspecified,
is_unspecified as _is_unspecified, # noqa: F401
)
from jax._src.sharding_specs import (

View File

@ -13,31 +13,31 @@
# limitations under the License.
from jax._src.lax.linalg import (
cholesky,
cholesky_p,
eig,
eig_p,
eigh,
eigh_p,
hessenberg,
hessenberg_p,
lu,
lu_p,
lu_pivots_to_permutation,
householder_product,
householder_product_p,
qr,
qr_p,
svd,
svd_p,
triangular_solve,
triangular_solve_p,
tridiagonal,
tridiagonal_p,
tridiagonal_solve,
tridiagonal_solve_p,
schur,
schur_p
cholesky as cholesky,
cholesky_p as cholesky_p,
eig as eig,
eig_p as eig_p,
eigh as eigh,
eigh_p as eigh_p,
hessenberg as hessenberg,
hessenberg_p as hessenberg_p,
lu as lu,
lu_p as lu_p,
lu_pivots_to_permutation as lu_pivots_to_permutation,
householder_product as householder_product,
householder_product_p as householder_product_p,
qr as qr,
qr_p as qr_p,
svd as svd,
svd_p as svd_p,
triangular_solve as triangular_solve,
triangular_solve_p as triangular_solve_p,
tridiagonal as tridiagonal,
tridiagonal_p as tridiagonal_p,
tridiagonal_solve as tridiagonal_solve,
tridiagonal_solve_p as tridiagonal_solve_p,
schur as schur,
schur_p as schur_p,
)

View File

@ -18,7 +18,7 @@
from jax.numpy import fft as fft
from jax.numpy import linalg as linalg
from jax._src.basearray import Array as ndarray
from jax._src.basearray import Array as ndarray # noqa: F401
from jax._src.dtypes import (
isdtype as isdtype,
@ -53,7 +53,7 @@ from jax._src.numpy.lax_numpy import (
bincount as bincount,
blackman as blackman,
block as block,
bool_ as bool, # Array API alias for bool_
bool_ as bool, # Array API alias for bool_ # noqa: F401
bool_ as bool_,
broadcast_arrays as broadcast_arrays,
broadcast_shapes as broadcast_shapes,

View File

@ -13,6 +13,6 @@
# limitations under the License.
from jax._src.scipy.stats.nbinom import (
logpmf,
pmf,
logpmf as logpmf,
pmf as pmf,
)

View File

@ -17,7 +17,6 @@
from jax._src.sharding import Sharding as Sharding
from jax._src.sharding_impls import (
XLACompatibleSharding as _deprecated_XLACompatibleSharding,
NamedSharding as NamedSharding,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
@ -28,7 +27,7 @@ from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
from jax._src.interpreters.pxla import Mesh as Mesh
from jax._src.mesh import AbstractMesh
from jax._src.mesh import AbstractMesh as AbstractMesh
_deprecations = {
# Finalized 2024-10-01; remove after 2025-01-01.

View File

@ -17,8 +17,6 @@ import importlib
import logging
import os
import pathlib
import platform
import sys
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb

View File

@ -17,7 +17,6 @@ import importlib
import logging
import os
import pathlib
import platform
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb

View File

@ -133,42 +133,3 @@ max-complexity = 18
"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
"docs/jep/9407-type-promotion.ipynb" = ["F811"]
"docs/autodidax.ipynb" = ["F811"]
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]