diff --git a/jax/__init__.py b/jax/__init__.py
index c6e073699..4f5c256b0 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -81,7 +81,7 @@ del _xc
 from jax._src.api import effects_barrier as effects_barrier
 from jax._src.api import block_until_ready as block_until_ready
-from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
+from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint  # noqa: F401
 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
 from jax._src.api import clear_backends as _deprecated_clear_backends
 from jax._src.api import clear_caches as clear_caches
@@ -122,7 +122,7 @@ from jax._src.xla_bridge import process_count as process_count
 from jax._src.xla_bridge import process_index as process_index
 from jax._src.xla_bridge import process_indices as process_indices
 from jax._src.callback import pure_callback as pure_callback
-from jax._src.ad_checkpoint import checkpoint_wrapper as remat
+from jax._src.ad_checkpoint import checkpoint_wrapper as remat  # noqa: F401
 from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
 from jax._src.api import value_and_grad as value_and_grad
 from jax._src.api import vjp as vjp
@@ -249,6 +249,6 @@ else:
 del _deprecation_getattr
 del _typing
-import jax.lib  # TODO(phawkins): remove this export.
+import jax.lib  # TODO(phawkins): remove this export.  # noqa: F401
diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py
index 9abb628f8..a01869de9 100644
--- a/jax/_src/clusters/__init__.py
+++ b/jax/_src/clusters/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .cluster import ClusterEnv
+from .cluster import ClusterEnv as ClusterEnv
 # Order of declaration of the cluster environments
 # will dictate the order in which they will be checked.
@@ -20,9 +20,9 @@ from .cluster import ClusterEnv
 # the user did not explicitly provide the arguments
 # to :func:`jax.distributed.initialize`, the first
 # available one from the list will be picked.
-from .ompi_cluster import OmpiCluster
-from .slurm_cluster import SlurmCluster
-from .mpi4py_cluster import Mpi4pyCluster
-from .cloud_tpu_cluster import GkeTpuCluster
-from .cloud_tpu_cluster import GceTpuCluster
-from .k8s_cluster import K8sCluster
+from .ompi_cluster import OmpiCluster as OmpiCluster
+from .slurm_cluster import SlurmCluster as SlurmCluster
+from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
+from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
+from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
+from .k8s_cluster import K8sCluster as K8sCluster
diff --git a/jax/_src/cudnn/__init__.py b/jax/_src/cudnn/__init__.py
index 23d1fa28f..d182269ea 100644
--- a/jax/_src/cudnn/__init__.py
+++ b/jax/_src/cudnn/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .fusion import cudnn_fusion
+from .fusion import cudnn_fusion as cudnn_fusion
diff --git a/jax/_src/debugger/__init__.py b/jax/_src/debugger/__init__.py
index 5e367d3fc..c765685cb 100644
--- a/jax/_src/debugger/__init__.py
+++ b/jax/_src/debugger/__init__.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from jax._src.debugger.core import breakpoint
+from jax._src.debugger.core import breakpoint as breakpoint
 from jax._src.debugger import cli_debugger
 from jax._src.debugger import colab_debugger
 from jax._src.debugger import web_debugger
diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py
index 5e6fa86f7..db03143f1 100644
--- a/jax/_src/lax/control_flow/__init__.py
+++ b/jax/_src/lax/control_flow/__init__.py
@@ -12,25 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Module for the control flow primitives."""
-from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
-                                             cummin, cummin_p, cumlogsumexp,
-                                             cumlogsumexp_p, cumprod,
-                                             cumprod_p, cumsum, cumsum_p,
-                                             cumred_reduce_window_impl,
-                                             fori_loop, map,
-                                             scan, scan_bind, scan_p,
-                                             _scan_impl, while_loop, while_p)
-from jax._src.lax.control_flow.conditionals import (cond, cond_p, switch,
-                                                    platform_dependent)
-from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
-                                              _custom_linear_solve_impl,
-                                              linear_solve_p)
-
+from jax._src.lax.control_flow.loops import (
+    associative_scan as associative_scan,
+    cummax as cummax,
+    cummax_p as cummax_p,
+    cummin as cummin,
+    cummin_p as cummin_p,
+    cumlogsumexp as cumlogsumexp,
+    cumlogsumexp_p as cumlogsumexp_p,
+    cumprod as cumprod,
+    cumprod_p as cumprod_p,
+    cumsum as cumsum,
+    cumsum_p as cumsum_p,
+    cumred_reduce_window_impl as cumred_reduce_window_impl,
+    fori_loop as fori_loop,
+    map as map,
+    scan as scan,
+    scan_bind as scan_bind,
+    scan_p as scan_p,
+    _scan_impl as _scan_impl,
+    while_loop as while_loop,
+    while_p as while_p,
+)
+from jax._src.lax.control_flow.conditionals import (
+    cond as cond,
+    cond_p as cond_p,
+    switch as switch,
+    platform_dependent as platform_dependent,
+)
+from jax._src.lax.control_flow.solves import (
+    custom_linear_solve as custom_linear_solve,
+    custom_root as custom_root,
+    _custom_linear_solve_impl as _custom_linear_solve_impl,
+    linear_solve_p as linear_solve_p,
+)
 # Private utilities used elsewhere in JAX
 # TODO(sharadmv): lift them into a more common place
-from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
-                                              _initial_style_jaxpr,
-                                              _initial_style_jaxprs_with_common_consts,
-                                              _check_tree_and_avals)
+from jax._src.lax.control_flow.common import (
+    _initial_style_open_jaxpr as _initial_style_open_jaxpr,
+    _initial_style_jaxpr as _initial_style_jaxpr,
+    _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
+    _check_tree_and_avals as _check_tree_and_avals,
+)
 # TODO(mattjj): fix dependent library which expects optimization_barrier_p here
-from jax._src.lax.lax import optimization_barrier_p
+from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p
diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py
index b72c6ee46..68a0f1553 100644
--- a/jax/_src/lib/__init__.py
+++ b/jax/_src/lib/__init__.py
@@ -21,7 +21,6 @@ import gc
 import os
 import pathlib
 import re
-from typing import Any
 
 try:
   import jaxlib as jaxlib
@@ -84,9 +83,9 @@ version = check_jaxlib_version(
 import jaxlib.cpu_feature_guard as cpu_feature_guard
 cpu_feature_guard.check_cpu_features()
-import jaxlib.utils as utils
+import jaxlib.utils as utils  # noqa: F401
 import jaxlib.xla_client as xla_client
-import jaxlib.lapack as lapack
+import jaxlib.lapack as lapack  # noqa: F401
 xla_extension = xla_client._xla
 pytree = xla_client._xla.pytree
@@ -99,18 +98,18 @@ def _xla_gc_callback(*args):
 gc.callbacks.append(_xla_gc_callback)
 try:
-  import jaxlib.cuda._versions as cuda_versions  # pytype: disable=import-error
+  import jaxlib.cuda._versions as cuda_versions  # pytype: disable=import-error  # noqa: F401
 except ImportError:
   try:
-    import jax_cuda12_plugin._versions as cuda_versions  # pytype: disable=import-error
+    import jax_cuda12_plugin._versions as cuda_versions  # pytype: disable=import-error  # noqa: F401
   except ImportError:
     cuda_versions = None
 
-import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
-import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
-import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
-import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
-import jaxlib.hlo_helpers as hlo_helpers  # pytype: disable=import-error
+import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error  # noqa: F401
+import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error  # noqa: F401
+import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error  # noqa: F401
+import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error  # noqa: F401
+import jaxlib.hlo_helpers as hlo_helpers  # pytype: disable=import-error  # noqa: F401
 # Jaxlib code is split between the Jax and the Tensorflow repositories.
 # Only for the internal usage of the JAX developers, we expose a version
@@ -118,10 +117,10 @@ import jaxlib.hlo_helpers as hlo_helpers  # pytype: disable=import-error
 # branch on the Jax github.
 xla_extension_version: int = getattr(xla_client, '_version', 0)
 
-import jaxlib.gpu_rnn as gpu_rnn  # pytype: disable=import-error
-import jaxlib.gpu_triton as gpu_triton  # pytype: disable=import-error
+import jaxlib.gpu_rnn as gpu_rnn  # pytype: disable=import-error  # noqa: F401
+import jaxlib.gpu_triton as gpu_triton  # pytype: disable=import-error  # noqa: F401
 
-import jaxlib.mosaic.python.tpu as tpu  # pytype: disable=import-error
+import jaxlib.mosaic.python.tpu as tpu  # pytype: disable=import-error  # noqa: F401
 # Version number for MLIR:Python APIs, provided by jaxlib.
 mlir_api_version = xla_client.mlir_api_version
diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py
index 2f1c88be4..38710d9db 100644
--- a/jax/_src/state/__init__.py
+++ b/jax/_src/state/__init__.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 """Module for state."""
 from jax._src.state.types import (
-    AbstractRef,
-    AccumEffect,
-    ReadEffect,
-    RefEffect,
-    StateEffect,
-    Transform,
-    TransformedRef,
-    WriteEffect,
-    get_ref_state_effects,
-    shaped_array_ref,
+    AbstractRef as AbstractRef,
+    AccumEffect as AccumEffect,
+    ReadEffect as ReadEffect,
+    RefEffect as RefEffect,
+    StateEffect as StateEffect,
+    Transform as Transform,
+    TransformedRef as TransformedRef,
+    WriteEffect as WriteEffect,
+    get_ref_state_effects as get_ref_state_effects,
+    shaped_array_ref as shaped_array_ref,
 )
diff --git a/jax/ad_checkpoint.py b/jax/ad_checkpoint.py
index 9eda640be..44c13e379 100644
--- a/jax/ad_checkpoint.py
+++ b/jax/ad_checkpoint.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 from jax._src.ad_checkpoint import (
-    checkpoint,
-    checkpoint_policies,
-    checkpoint_name,
-    print_saved_residuals,
-    remat,
+    checkpoint as checkpoint,
+    checkpoint_policies as checkpoint_policies,
+    checkpoint_name as checkpoint_name,
+    print_saved_residuals as print_saved_residuals,
+    remat as remat,
 )
 from jax._src.interpreters.partial_eval import (
-    Recompute,
-    Saveable,
-    Offloadable,
+    Recompute as Recompute,
+    Saveable as Saveable,
+    Offloadable as Offloadable,
 )
diff --git a/jax/api_util.py b/jax/api_util.py
index 77b3ae36f..6cd731741 100644
--- a/jax/api_util.py
+++ b/jax/api_util.py
@@ -14,12 +14,12 @@
 
 from jax._src.api_util import (
-    argnums_partial,
-    donation_vector,
-    flatten_axes,
-    flatten_fun,
-    flatten_fun_nokwargs,
-    rebase_donate_argnums,
-    safe_map,
-    shaped_abstractify,
+    argnums_partial as argnums_partial,
+    donation_vector as donation_vector,
+    flatten_axes as flatten_axes,
+    flatten_fun as flatten_fun,
+    flatten_fun_nokwargs as flatten_fun_nokwargs,
+    rebase_donate_argnums as rebase_donate_argnums,
+    safe_map as safe_map,
+    shaped_abstractify as shaped_abstractify,
 )
diff --git a/jax/cloud_tpu_init.py b/jax/cloud_tpu_init.py
index 8cc49ac3b..8b886eb6b 100644
--- a/jax/cloud_tpu_init.py
+++ b/jax/cloud_tpu_init.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from jax._src.cloud_tpu_init import cloud_tpu_init
+from jax._src.cloud_tpu_init import cloud_tpu_init as cloud_tpu_init
diff --git a/jax/core.py b/jax/core.py
index 90ef668b2..9682d106e 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -42,7 +42,7 @@ from jax._src.core import (
   Literal as Literal,
   MainTrace as MainTrace,
   MapPrimitive as MapPrimitive,
-  nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,
+  nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,  # noqa: F401
   OpaqueTraceState as OpaqueTraceState,
   NameGatheringSubst as NameGatheringSubst,
   OutDBIdx as OutDBIdx,
@@ -58,9 +58,9 @@ from jax._src.core import (
   TraceStack as TraceStack,
   TraceState as TraceState,
   Tracer as Tracer,
-  unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,
-  unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE,
-  unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE,
+  unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,  # noqa: F401
+  unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE,  # noqa: F401
+  unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE,  # noqa: F401
   UnshapedArray as UnshapedArray,
   Value as Value,
   Var as Var,
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index ea1ef4f02..3628ae4aa 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -16,9 +16,9 @@
 # See PEP 484 & https://github.com/jax-ml/jax/issues/7570
 from jax._src.custom_derivatives import (
-  _initial_style_jaxpr,
-  _sum_tangents,
-  _zeros_like_pytree,
+  _initial_style_jaxpr,  # noqa: F401
+  _sum_tangents,  # noqa: F401
+  _zeros_like_pytree,  # noqa: F401
   closure_convert as closure_convert,
   custom_gradient as custom_gradient,
   custom_jvp as custom_jvp,
diff --git a/jax/dtypes.py b/jax/dtypes.py
index a6f1b7645..4c1136360 100644
--- a/jax/dtypes.py
+++ b/jax/dtypes.py
@@ -18,10 +18,10 @@ from jax._src.dtypes import (
   bfloat16 as bfloat16,
   canonicalize_dtype as canonicalize_dtype,
-  finfo,  # TODO(phawkins): switch callers to jnp.finfo?
+  finfo,  # TODO(phawkins): switch callers to jnp.finfo?  # noqa: F401
   float0 as float0,
-  iinfo,  # TODO(phawkins): switch callers to jnp.iinfo?
-  issubdtype,  # TODO(phawkins): switch callers to jnp.issubdtype?
+  iinfo,  # TODO(phawkins): switch callers to jnp.iinfo?  # noqa: F401
+  issubdtype,  # TODO(phawkins): switch callers to jnp.issubdtype?  # noqa: F401
   extended as extended,
   prng_key as prng_key,
   result_type as result_type,
diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py
index 2620f5cc7..abe5eea3e 100644
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -29,11 +29,9 @@ from typing import Any, Optional
 
 import jax
 from jax._src import array
-from jax._src import config
 from jax._src import distributed
 from jax._src import sharding
-from jax._src import sharding_impls
-from jax._src.layout import Layout, DeviceLocalLayout as DLL
+from jax._src.layout import Layout
 from jax._src import typing
 from jax._src import util
 from jax._src.lib import xla_extension as xe
diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py
index 619936379..3df3eb25c 100644
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -27,7 +27,6 @@ import jax
 import jax.numpy as jnp
 from jax._src import test_util as jtu
 from jax._src import array
-from jax._src import xla_bridge as xb
 from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding
 from jax.sharding import PartitionSpec as P
 from jax.experimental.array_serialization import serialization
diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py
index 77432f9eb..5c96506b3 100644
--- a/jax/experimental/jax2tf/examples/mnist_lib.py
+++ b/jax/experimental/jax2tf/examples/mnist_lib.py
@@ -30,7 +30,6 @@ from typing import Any
 import warnings
 
 from absl import flags
-import flax
 from flax import linen as nn
 import jax
diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py
index cc34d78e8..32b9a0f2a 100644
--- a/jax/experimental/jax2tf/tests/cross_compilation_check.py
+++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py
@@ -39,8 +39,8 @@ from absl import logging
 import numpy.random as npr
 
 import jax  # Must import before TF
-from jax.experimental import jax2tf  # Defines needed flags
-from jax._src import test_util  # Defines needed flags
+from jax.experimental import jax2tf  # Defines needed flags  # noqa: F401
+from jax._src import test_util  # Defines needed flags  # noqa: F401
 
 jax.config.parse_flags_with_absl()
diff --git a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
index 276927621..1653e3b8e 100644
--- a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
+++ b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
@@ -28,7 +28,6 @@ import unittest
 
 from absl.testing import absltest
 
-import jax
 from jax._src import config
 from jax._src import test_util as jtu
diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py
index 34844aa77..2863ca4ed 100644
--- a/jax/experimental/jax2tf/tests/primitives_test.py
+++ b/jax/experimental/jax2tf/tests/primitives_test.py
@@ -67,7 +67,6 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax.experimental import jax2tf
 from jax.interpreters import mlir
-from jax._src.interpreters import xla
 import numpy as np
 import tensorflow as tf
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 0ab35efb4..38af6d9d7 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -45,7 +45,6 @@ from jax._src import util
 from jax._src.export import shape_poly
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import control_flow as lax_control_flow
-from jax._src.lib import xla_client
 import numpy as np
 
 from jax.experimental.jax2tf.tests import tf_test_util
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 247135395..8fe9a1dd9 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -22,8 +22,6 @@ from collections.abc import Sequence
 import contextlib
 from functools import partial
 import logging
-import math
-import os
 import re
 from typing import Any
 import unittest
diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py
index b4989e151..ef19f94c4 100644
--- a/jax/experimental/key_reuse/_core.py
+++ b/jax/experimental/key_reuse/_core.py
@@ -16,7 +16,7 @@ from __future__ import annotations
 
 from collections import defaultdict
 from collections.abc import Callable, Iterator
-from functools import partial, reduce, total_ordering, wraps
+from functools import partial, reduce, total_ordering
 from typing import Any, NamedTuple
 
 import jax
@@ -25,7 +25,6 @@ from jax import tree_util
 from jax.errors import KeyReuseError
 from jax.interpreters import batching, mlir
 from jax._src import api_util
-from jax._src import config
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src import pjit
diff --git a/jax/experimental/mosaic/__init__.py b/jax/experimental/mosaic/__init__.py
index 867f485cb..5527bdaab 100644
--- a/jax/experimental/mosaic/__init__.py
+++ b/jax/experimental/mosaic/__init__.py
@@ -14,5 +14,5 @@
 # ==============================================================================
 """Contains bindings for Mosaic."""
 
-from jax._src.tpu_custom_call import as_tpu_kernel
-from jax._src.tpu_custom_call import lower_module_to_custom_call
+from jax._src.tpu_custom_call import as_tpu_kernel as as_tpu_kernel
+from jax._src.tpu_custom_call import lower_module_to_custom_call as lower_module_to_custom_call
diff --git a/jax/experimental/mosaic/dialects.py b/jax/experimental/mosaic/dialects.py
index bbe82fe83..4378eca99 100644
--- a/jax/experimental/mosaic/dialects.py
+++ b/jax/experimental/mosaic/dialects.py
@@ -14,4 +14,4 @@
 # ==============================================================================
 """Contains bindings for Mosaic MLIR dialects."""
 
-from jax._src.lib import tpu
+from jax._src.lib import tpu as tpu
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index dac0c16fe..5d8a4dd9f 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -13,55 +13,55 @@
 # limitations under the License.
 # ==============================================================================
 
-from jax import ShapeDtypeStruct
+from jax import ShapeDtypeStruct as ShapeDtypeStruct
 from .core import (
-    Barrier,
-    ClusterBarrier,
-    LaunchContext,
-    MemRefTransform,
-    TMABarrier,
-    TileTransform,
-    TransposeTransform,
-    Union,
-    as_gpu_kernel,
+    Barrier as Barrier,
+    ClusterBarrier as ClusterBarrier,
+    LaunchContext as LaunchContext,
+    MemRefTransform as MemRefTransform,
+    TMABarrier as TMABarrier,
+    TileTransform as TileTransform,
+    TransposeTransform as TransposeTransform,
+    Union as Union,
+    as_gpu_kernel as as_gpu_kernel,
 )
 from .fragmented_array import (
-    FragmentedArray,
-    FragmentedLayout,
-    WGMMA_LAYOUT,
-    WGMMA_ROW_LAYOUT,
-    WGSplatFragLayout,
-    WGStridedFragLayout,
+    FragmentedArray as FragmentedArray,
+    FragmentedLayout as FragmentedLayout,
+    WGMMA_LAYOUT as WGMMA_LAYOUT,
+    WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
+    WGSplatFragLayout as WGSplatFragLayout,
+    WGStridedFragLayout as WGStridedFragLayout,
 )
 from .utils import (
-    BarrierRef,
-    CollectiveBarrierRef,
-    DynamicSlice,
-    Partition,
-    Partition1D,
-    bytewidth,
-    c,
-    commit_shared,
-    debug_print,
-    ds,
-    fori,
-    memref_fold,
-    memref_slice,
-    memref_reshape,
-    memref_transpose,
-    memref_unfold,
-    memref_unsqueeze,
-    single_thread,
-    single_thread_predicate,
-    thread_idx,
-    tile_shape,
-    warp_idx,
-    warpgroup_barrier,
-    warpgroup_idx,
-    when,
+    BarrierRef as BarrierRef,
+    CollectiveBarrierRef as CollectiveBarrierRef,
+    DynamicSlice as DynamicSlice,
+    Partition as Partition,
+    Partition1D as Partition1D,
+    bytewidth as bytewidth,
+    c as c,
+    commit_shared as commit_shared,
+    debug_print as debug_print,
+    ds as ds,
+    fori as fori,
+    memref_fold as memref_fold,
+    memref_slice as memref_slice,
+    memref_reshape as memref_reshape,
+    memref_transpose as memref_transpose,
+    memref_unfold as memref_unfold,
+    memref_unsqueeze as memref_unsqueeze,
+    single_thread as single_thread,
+    single_thread_predicate as single_thread_predicate,
+    thread_idx as thread_idx,
+    tile_shape as tile_shape,
+    warp_idx as warp_idx,
+    warpgroup_barrier as warpgroup_barrier,
+    warpgroup_idx as warpgroup_idx,
+    when as when,
 )
 from .wgmma import (
-    WGMMAAccumulator,
-    WGMMALayout,
-    wgmma,
+    WGMMAAccumulator as WGMMAAccumulator,
+    WGMMALayout as WGMMALayout,
+    wgmma as wgmma,
 )
diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py
index c2a2a6fe6..3394eaaa0 100644
--- a/jax/experimental/mosaic/gpu/examples/flash_attention.py
+++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -17,21 +17,17 @@ import contextlib
 import dataclasses
 import enum
 import itertools
-import os
 import warnings
 
-from absl import app
 import jax
 from jax import random
 from jax._src.interpreters import mlir
-from jax._src import test_util as jtu
 from jax.experimental.mosaic.gpu import profiler
 from jax.experimental.mosaic.gpu import *  # noqa: F403
 import jax.numpy as jnp
 from jaxlib.mlir import ir
 from jaxlib.mlir.dialects import arith
 from jaxlib.mlir.dialects import gpu
-from jaxlib.mlir.dialects import nvgpu
 from jaxlib.mlir.dialects import nvvm
 from jaxlib.mlir.dialects import scf
 import numpy as np
diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py
index c56c5cd6b..ce99bf423 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul.py
@@ -16,7 +16,6 @@
 
 import dataclasses
 import itertools
-import math
 from typing import Any
 
 import jax
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 2709d0075..87ffe0929 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -20,7 +20,7 @@ import dataclasses
 import enum
 import functools
 import math
-from typing import Any, Literal, cast
+from typing import Any, Literal
 
 import jax
 from jax import numpy as jnp
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py
index 56003ea7a..803efa190 100644
--- a/jax/experimental/multihost_utils.py
+++ b/jax/experimental/multihost_utils.py
@@ -16,13 +16,12 @@ from __future__ import annotations
 
 from functools import partial, lru_cache
-from typing import Optional
 import zlib
 from typing import Any
 
 import jax
 import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map, tree_unflatten
+from jax.tree_util import tree_flatten, tree_unflatten
 from jax._src import core
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index 3eee14fea..34cb5328f 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -18,47 +18,47 @@ See the Pallas documentation at
 https://jax.readthedocs.io/en/latest/pallas.html.
 """
 
-from jax._src.pallas.core import Blocked
-from jax._src.pallas.core import BlockSpec
-from jax._src.pallas.core import CompilerParams
-from jax._src.pallas.core import core_map
-from jax._src.pallas.core import CostEstimate
-from jax._src.pallas.core import GridSpec
-from jax._src.pallas.core import IndexingMode
-from jax._src.pallas.core import MemorySpace
-from jax._src.pallas.core import MemoryRef
-from jax._src.pallas.core import no_block_spec
-from jax._src.pallas.core import Unblocked
-from jax._src.pallas.core import unblocked
-from jax._src.pallas.pallas_call import pallas_call
-from jax._src.pallas.pallas_call import pallas_call_p
-from jax._src.pallas.primitives import atomic_add
-from jax._src.pallas.primitives import atomic_and
-from jax._src.pallas.primitives import atomic_cas
-from jax._src.pallas.primitives import atomic_max
-from jax._src.pallas.primitives import atomic_min
-from jax._src.pallas.primitives import atomic_or
-from jax._src.pallas.primitives import atomic_xchg
-from jax._src.pallas.primitives import atomic_xor
-from jax._src.pallas.primitives import debug_print
-from jax._src.pallas.primitives import dot
-from jax._src.pallas.primitives import load
-from jax._src.pallas.primitives import max_contiguous
-from jax._src.pallas.primitives import multiple_of
-from jax._src.pallas.primitives import num_programs
-from jax._src.pallas.primitives import program_id
-from jax._src.pallas.primitives import run_scoped
-from jax._src.pallas.primitives import store
-from jax._src.pallas.primitives import swap
-from jax._src.pallas.utils import cdiv
-from jax._src.pallas.utils import next_power_of_2
-from jax._src.pallas.utils import strides_from_shape
-from jax._src.pallas.utils import when
-from jax._src.state.discharge import run_state
-from jax._src.state.indexing import ds
-from jax._src.state.indexing import dslice
-from jax._src.state.indexing import Slice
-from jax._src.state.primitives import broadcast_to
+from jax._src.pallas.core import Blocked as Blocked
+from jax._src.pallas.core import BlockSpec as BlockSpec
+from jax._src.pallas.core import CompilerParams as CompilerParams
+from jax._src.pallas.core import core_map as core_map
+from jax._src.pallas.core import CostEstimate as CostEstimate
+from jax._src.pallas.core import GridSpec as GridSpec
+from jax._src.pallas.core import IndexingMode as IndexingMode
+from jax._src.pallas.core import MemorySpace as MemorySpace
+from jax._src.pallas.core import MemoryRef as MemoryRef
+from jax._src.pallas.core import no_block_spec as no_block_spec
+from jax._src.pallas.core import Unblocked as Unblocked
+from jax._src.pallas.core import unblocked as unblocked
+from jax._src.pallas.pallas_call import pallas_call as pallas_call
+from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
+from jax._src.pallas.primitives import atomic_add as atomic_add
+from jax._src.pallas.primitives import atomic_and as atomic_and
+from jax._src.pallas.primitives import atomic_cas as atomic_cas
+from jax._src.pallas.primitives import atomic_max as atomic_max
+from jax._src.pallas.primitives import atomic_min as atomic_min
+from jax._src.pallas.primitives import atomic_or as atomic_or
+from jax._src.pallas.primitives import atomic_xchg as atomic_xchg
+from jax._src.pallas.primitives import atomic_xor as atomic_xor
+from jax._src.pallas.primitives import debug_print as debug_print
+from jax._src.pallas.primitives import dot as dot
+from jax._src.pallas.primitives import load as load
+from jax._src.pallas.primitives import max_contiguous as max_contiguous
+from jax._src.pallas.primitives import multiple_of as multiple_of
+from jax._src.pallas.primitives import num_programs as num_programs
+from jax._src.pallas.primitives import program_id as program_id
+from jax._src.pallas.primitives import run_scoped as run_scoped
+from jax._src.pallas.primitives import store as store
+from jax._src.pallas.primitives import swap as swap
+from jax._src.pallas.utils import cdiv as cdiv
+from jax._src.pallas.utils import next_power_of_2 as next_power_of_2
+from jax._src.pallas.utils import strides_from_shape as strides_from_shape
+from jax._src.pallas.utils import when as when
+from jax._src.state.discharge import run_state as run_state
+from jax._src.state.indexing import ds as ds
+from jax._src.state.indexing import dslice as dslice
+from jax._src.state.indexing import Slice as Slice
+from jax._src.state.primitives import broadcast_to as broadcast_to
 
 ANY = MemorySpace.ANY
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 78e23afff..80eaea753 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -17,23 +17,23 @@ These APIs are highly unstable and can change weekly. Use at your own risk.
""" -from jax._src.pallas.mosaic_gpu.core import Barrier -from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec -from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace -from jax._src.pallas.mosaic_gpu.core import SwizzleTransform -from jax._src.pallas.mosaic_gpu.core import TilingTransform -from jax._src.pallas.mosaic_gpu.core import TransposeTransform -from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef -from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC -from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem -from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import barrier_wait -from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive -from jax._src.pallas.mosaic_gpu.primitives import set_max_registers -from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import wgmma -from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait +from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier +from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec +from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform +from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 +from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem +from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait +from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive +from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers +from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma +from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. GMEM = GPUMemorySpace.GMEM diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index 98b26e2d7..2a7824315 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -18,8 +18,6 @@ from __future__ import annotations import functools -from typing import Optional - import jax from jax import lax import jax.numpy as jnp diff --git a/jax/experimental/pallas/ops/tpu/megablox/__init__.py b/jax/experimental/pallas/ops/tpu/megablox/__init__.py index 2c7391a18..3065c78c2 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/__init__.py +++ b/jax/experimental/pallas/ops/tpu/megablox/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax.experimental.pallas.ops.tpu.megablox.ops import gmm +from jax.experimental.pallas.ops.tpu.megablox.ops import gmm as gmm diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py b/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py index 1cce79926..5bb5c4f5c 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention +from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as paged_attention diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/__init__.py b/jax/experimental/pallas/ops/tpu/splash_attention/__init__.py index e2b9a1c4d..4416f2f60 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/__init__.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/__init__.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask -from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes as BlockSizes +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference as make_masked_mha_reference +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference as make_masked_mqa_reference +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha as make_splash_mha +from 
jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device as make_splash_mha_single_device +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa as make_splash_mqa +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device as make_splash_mqa_single_device +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout as QKVLayout +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds as SegmentIds +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask as CausalMask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask as FullMask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask as LocalMask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask as make_causal_mask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask as make_local_attention_mask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask as make_random_mask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask as Mask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask as MultiHeadMask +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask as NumpyMask diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 3c672b8db..99ab04898 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -18,7 +18,7 @@ from __future__ import annotations import collections from collections.abc import Callable import functools -from typing import Dict, List, NamedTuple, Set, Tuple +from typing import NamedTuple from jax import util as jax_util from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib import numpy as np diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 40fa9dc45..d00e0e90c 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -14,47 +14,47 @@ """Mosaic-specific Pallas APIs.""" -from jax._src.pallas.mosaic import core -from jax._src.pallas.mosaic.core import create_tensorcore_mesh -from jax._src.pallas.mosaic.core import dma_semaphore -from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore -from jax._src.pallas.mosaic.core import SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace -from jax._src.pallas.mosaic.core import TPUCompilerParams -from jax._src.pallas.mosaic.core import runtime_assert_enabled -from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert -from jax._src.pallas.mosaic.lowering import LoweringException -from jax._src.pallas.mosaic.pipeline import ARBITRARY -from jax._src.pallas.mosaic.pipeline import BufferedRef -from jax._src.pallas.mosaic.pipeline import emit_pipeline -from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations -from 
jax._src.pallas.mosaic.pipeline import get_pipeline_schedule -from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations -from jax._src.pallas.mosaic.pipeline import PARALLEL -from jax._src.pallas.mosaic.primitives import async_copy -from jax._src.pallas.mosaic.primitives import async_remote_copy -from jax._src.pallas.mosaic.primitives import bitcast -from jax._src.pallas.mosaic.primitives import delay -from jax._src.pallas.mosaic.primitives import device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType -from jax._src.pallas.mosaic.primitives import get_barrier_semaphore -from jax._src.pallas.mosaic.primitives import make_async_copy -from jax._src.pallas.mosaic.primitives import make_async_remote_copy -from jax._src.pallas.mosaic.primitives import prng_random_bits -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import repeat -from jax._src.pallas.mosaic.primitives import roll -from jax._src.pallas.mosaic.primitives import semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait -from jax._src.pallas.mosaic.random import to_pallas_key +from jax._src.pallas.mosaic import core as core +from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh +from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore +from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec +from jax._src.pallas.mosaic.core import semaphore as semaphore +from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType +from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace +from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams +from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled +from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 +from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException +from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY +from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef +from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline +from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations +from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule +from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations as make_pipeline_allocations +from jax._src.pallas.mosaic.pipeline import PARALLEL as PARALLEL +from jax._src.pallas.mosaic.primitives import async_copy as async_copy +from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy +from jax._src.pallas.mosaic.primitives import bitcast as bitcast +from jax._src.pallas.mosaic.primitives import delay as delay +from jax._src.pallas.mosaic.primitives import device_id as device_id +from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy +from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy +from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits +from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed +from 
jax._src.pallas.mosaic.primitives import repeat as repeat +from jax._src.pallas.mosaic.primitives import roll as roll +from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read +from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait +from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key # Remove this import after October 22th 2024. -from jax._src.tpu_custom_call import CostEstimate +from jax._src.tpu_custom_call import CostEstimate as CostEstimate # TODO(cperivol): Temporary alias to the global run_scoped. Remove # this once everyone has migrated to the pallas core one. -from jax._src.pallas.primitives import run_scoped +from jax._src.pallas.primitives import run_scoped as run_scoped import types from jax._src.pallas.mosaic.verification import assume diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py index bcee04374..06adb9e6d 100644 --- a/jax/experimental/pallas/triton.py +++ b/jax/experimental/pallas/triton.py @@ -14,7 +14,7 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton.core import TritonCompilerParams -from jax._src.pallas.triton.primitives import approx_tanh -from jax._src.pallas.triton.primitives import debug_barrier -from jax._src.pallas.triton.primitives import elementwise_inline_asm +from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams +from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh +from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier +from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 1c0614da5..2d65141a2 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -17,7 +17,6 @@ from __future__ import annotations import pickle import io -from typing import Optional, Union import jax from jax._src.lib import xla_client as xc diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 7a04c88c5..24bed5034 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -52,7 +52,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo, sdy +from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 5e64e1e14..20f0f806a 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -34,7 +34,6 @@ from __future__ import annotations from functools import partial import operator -from typing import Optional, Union import jax from jax import tree_util diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9f2f0f69b..f65f7b0a1 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -49,9 +49,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax.lax import ( _const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers) from jax._src.lax.slicing import GatherDimensionNumbers, 
GatherScatterMode -from jax._src.lib.mlir import ir from jax._src.lib import gpu_sparse -from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.setops import _unique from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7275d6bb2..8aa7d80c7 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -31,8 +31,7 @@ from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( nfold_vmap, _count_stored_elements, - _csr_to_coo, _dot_general_validated_shape, - CuSparseEfficiencyWarning, SparseInfo, Shape) + _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) from jax.util import split_list, safe_zip from jax._src import api_util diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 89d08f109..84171855b 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -17,7 +17,6 @@ from __future__ import annotations from functools import partial import operator -from typing import Optional import warnings import numpy as np @@ -35,7 +34,7 @@ from jax._src.interpreters import ad from jax._src.lax.lax import _const from jax._src.lib import gpu_sparse from jax._src.numpy.util import promote_dtypes -from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.typing import Array, DTypeLike import jax.numpy as jnp diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 365c43652..77c975130 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -26,7 +26,6 @@ from jax import lax from jax import tree_util from jax._src import test_util as jtu from jax._src.lax.lax import DotDimensionNumbers -from jax._src.lib import gpu_sparse from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 9aa9e42f2..7ef1ed781 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -15,7 +15,7 @@ """Sparse utilities.""" import functools -from typing import Any, NamedTuple, Union +from typing import NamedTuple import numpy as np import jax @@ -23,8 +23,6 @@ from jax import lax from jax import tree_util from jax import vmap from jax._src import core -from jax._src import dtypes -from jax._src import stages from jax._src.api_util import flatten_axes import jax.numpy as jnp from jax.util import safe_zip diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 7866564e9..06be2b748 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -14,9 +14,7 @@ from __future__ import annotations -import abc from collections.abc import Sequence -from typing import Optional import jax from jax.experimental import mesh_utils diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index edfc56ddd..41456c4cf 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -27,7 +27,7 @@ from jax._src.interpreters.mlir import ( Token as Token, TokenSet as TokenSet, Value as Value, - call_lowering as _call_lowering, + call_lowering as _call_lowering, # noqa: F401 _lowerings as _lowerings, _platform_specific_lowerings as _platform_specific_lowerings, aval_to_ir_type as aval_to_ir_type, @@ -41,7 +41,7 @@ from jax._src.interpreters.mlir import ( 
dtype_to_ir_type as dtype_to_ir_type, emit_python_callback as emit_python_callback, flatten_ir_types as flatten_ir_types, - flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me + flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, func_dialect as func_dialect, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index c5aa31a53..15c9a2cfe 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -38,9 +38,9 @@ from jax._src.op_shardings import ( from jax._src.sharding_impls import ( ArrayMapping as ArrayMapping, - UNSPECIFIED as _UNSPECIFIED, + UNSPECIFIED as _UNSPECIFIED, # noqa: F401 array_mapping_to_axis_resources as array_mapping_to_axis_resources, - is_unspecified as _is_unspecified, + is_unspecified as _is_unspecified, # noqa: F401 ) from jax._src.sharding_specs import ( diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index bc6f53d62..dad62f099 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -13,31 +13,31 @@ # limitations under the License. from jax._src.lax.linalg import ( - cholesky, - cholesky_p, - eig, - eig_p, - eigh, - eigh_p, - hessenberg, - hessenberg_p, - lu, - lu_p, - lu_pivots_to_permutation, - householder_product, - householder_product_p, - qr, - qr_p, - svd, - svd_p, - triangular_solve, - triangular_solve_p, - tridiagonal, - tridiagonal_p, - tridiagonal_solve, - tridiagonal_solve_p, - schur, - schur_p + cholesky as cholesky, + cholesky_p as cholesky_p, + eig as eig, + eig_p as eig_p, + eigh as eigh, + eigh_p as eigh_p, + hessenberg as hessenberg, + hessenberg_p as hessenberg_p, + lu as lu, + lu_p as lu_p, + lu_pivots_to_permutation as lu_pivots_to_permutation, + householder_product as householder_product, + householder_product_p as householder_product_p, + qr as qr, + qr_p as qr_p, + svd as svd, + svd_p as svd_p, + triangular_solve as triangular_solve, + triangular_solve_p as triangular_solve_p, + tridiagonal as tridiagonal, + tridiagonal_p as tridiagonal_p, + tridiagonal_solve as tridiagonal_solve, + tridiagonal_solve_p as tridiagonal_solve_p, + schur as schur, + schur_p as schur_p, ) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index bd8068729..d50d55033 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -18,7 +18,7 @@ from jax.numpy import fft as fft from jax.numpy import linalg as linalg -from jax._src.basearray import Array as ndarray +from jax._src.basearray import Array as ndarray # noqa: F401 from jax._src.dtypes import ( isdtype as isdtype, @@ -53,7 +53,7 @@ from jax._src.numpy.lax_numpy import ( bincount as bincount, blackman as blackman, block as block, - bool_ as bool, # Array API alias for bool_ + bool_ as bool, # Array API alias for bool_ # noqa: F401 bool_ as bool_, broadcast_arrays as broadcast_arrays, broadcast_shapes as broadcast_shapes, diff --git a/jax/scipy/stats/nbinom.py b/jax/scipy/stats/nbinom.py index 83e0ba9fc..6a3dc502d 100644 --- a/jax/scipy/stats/nbinom.py +++ b/jax/scipy/stats/nbinom.py @@ -13,6 +13,6 @@ # limitations under the License. 
 from jax._src.scipy.stats.nbinom import (
-    logpmf,
-    pmf,
+    logpmf as logpmf,
+    pmf as pmf,
 )
diff --git a/jax/sharding.py b/jax/sharding.py
index 9a2d8db21..3c41439ef 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -17,7 +17,6 @@ from jax._src.sharding import Sharding as Sharding
 
 from jax._src.sharding_impls import (
-    XLACompatibleSharding as _deprecated_XLACompatibleSharding,
     NamedSharding as NamedSharding,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
@@ -28,7 +27,7 @@ from jax._src.partition_spec import (
     PartitionSpec as PartitionSpec,
 )
 from jax._src.interpreters.pxla import Mesh as Mesh
-from jax._src.mesh import AbstractMesh
+from jax._src.mesh import AbstractMesh as AbstractMesh
 
 _deprecations = {
     # Finalized 2024-10-01; remove after 2025-01-01.
diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py
index 9867c07b1..f7d4f9a19 100644
--- a/jax_plugins/cuda/__init__.py
+++ b/jax_plugins/cuda/__init__.py
@@ -17,8 +17,6 @@ import importlib
 import logging
 import os
 import pathlib
-import platform
-import sys
 
 from jax._src.lib import xla_client
 import jax._src.xla_bridge as xb
diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py
index 3dbcaf449..8b176b675 100644
--- a/jax_plugins/rocm/__init__.py
+++ b/jax_plugins/rocm/__init__.py
@@ -17,7 +17,6 @@ import importlib
 import logging
 import os
 import pathlib
-import platform
 
 from jax._src.lib import xla_client
 import jax._src.xla_bridge as xb
diff --git a/pyproject.toml b/pyproject.toml
index 1d901d26d..9f5f06e7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -133,42 +133,3 @@ max-complexity = 18
 "docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"]
 "docs/jep/9407-type-promotion.ipynb" = ["F811"]
 "docs/autodidax.ipynb" = ["F811"]
-# Note: we don't use jax/*.py because this matches contents of jax/_src
-"__init__.py" = ["F401"]
-"jax/abstract_arrays.py" = ["F401"]
-"jax/ad_checkpoint.py" = ["F401"]
-"jax/api_util.py" = ["F401"]
-"jax/cloud_tpu_init.py" = ["F401"]
-"jax/core.py" = ["F401"]
-"jax/custom_batching.py" = ["F401"]
-"jax/custom_derivatives.py" = ["F401"]
-"jax/custom_transpose.py" = ["F401"]
-"jax/debug.py" = ["F401"]
-"jax/distributed.py" = ["F401"]
-"jax/dlpack.py" = ["F401"]
-"jax/dtypes.py" = ["F401"]
-"jax/errors.py" = ["F401"]
-"jax/experimental/*.py" = ["F401"]
-"jax/extend/*.py" = ["F401"]
-"jax/flatten_util.py" = ["F401"]
-"jax/interpreters/ad.py" = ["F401"]
-"jax/interpreters/batching.py" = ["F401"]
-"jax/interpreters/mlir.py" = ["F401"]
-"jax/interpreters/partial_eval.py" = ["F401"]
-"jax/interpreters/pxla.py" = ["F401"]
-"jax/interpreters/xla.py" = ["F401"]
-"jax/lax/*.py" = ["F401"]
-"jax/linear_util.py" = ["F401"]
-"jax/monitoring.py" = ["F401"]
-"jax/nn/*.py" = ["F401"]
-"jax/numpy/*.py" = ["F401"]
-"jax/prng.py" = ["F401"]
-"jax/profiler.py" = ["F401"]
-"jax/random.py" = ["F401"]
-"jax/scipy/*.py" = ["F401"]
-"jax/sharding.py" = ["F401"]
-"jax/stages.py" = ["F401"]
-"jax/test_util.py" = ["F401"]
-"jax/tree_util.py" = ["F401"]
-"jax/typing.py" = ["F401"]
-"jax/util.py" = ["F401"]
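Note on the pattern used throughout the diff above: every hunk applies one of two mechanisms for marking an import as intentional rather than unused. A redundant alias (`from m import X as X`) is the PEP 484 convention for an explicit re-export, which type checkers and Ruff/flake8 treat as deliberate, while an inline `# noqa: F401` suppresses the unused-import warning on a single line (used where the re-exported name differs from the source name, or where the import exists only for side effects or backward compatibility). With one of the two on every re-export, the blanket per-file F401 ignores removed from pyproject.toml become unnecessary. A minimal sketch of both idioms, using a hypothetical package `pkg` that is not part of this diff:

# pkg/_impl.py -- hypothetical private implementation module
def run() -> str:
    return "ok"

# pkg/__init__.py -- hypothetical public facade
# Redundant alias: marks `run` as an intentional re-export (PEP 484),
# so `from pkg import run` type-checks and linters do not report F401.
from pkg._impl import run as run
# Per-line suppression: keeps an import that is re-exported under a
# different name (or kept only for compatibility) lint-clean.
from pkg._impl import run as start  # noqa: F401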