From 70f6ab3128a5d7ef26c924e4c2c51d0301bd58ea Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 7 Jun 2024 12:07:07 +0100 Subject: [PATCH] Updated the type annotations of *_spec= parameters of pl.pallas_call The previous type did not work for nested pytrees and for some reason neither pytype nor mypy flagged that. I also re-enabled type checking for most pallas/*.py files. --- jax/_src/pallas/core.py | 24 ++++++++----------- jax/_src/pallas/mosaic/core.py | 10 +++----- jax/_src/pallas/pallas_call.py | 10 ++++---- jax/_src/pallas/primitives.py | 8 +++---- .../pallas/ops/tpu/flash_attention.py | 6 ++--- 5 files changed, 24 insertions(+), 34 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 5e044a0a6..8820825f0 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,13 +15,12 @@ """Module for pallas-core functionality.""" from __future__ import annotations +from collections.abc import Iterator, Sequence import copy -from collections.abc import Sequence import contextlib import dataclasses import functools from typing import Any, Callable, Union -from collections.abc import Iterator from jax._src import api_util from jax._src import core as jax_core @@ -33,9 +32,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge import jax.numpy as jnp -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - partial = functools.partial Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic. DynamicGrid = tuple[Union[int, jax_core.Array], ...] @@ -156,6 +152,10 @@ class BlockSpec: return out +# A PyTree of BlockSpec | NoBlockSpec. +BlockSpecTree = Any + + @dataclasses.dataclass(frozen=True) class BlockMapping: block_shape: tuple[Mapped | int, ...] @@ -201,7 +201,7 @@ class GridMapping: def static_grid(self) -> StaticGrid: if self.num_dynamic_grid_bounds: raise ValueError("Expected a grid with fully static bounds") - return self.grid # typing: ignore + return self.grid # type: ignore def _preprocess_grid(grid: Grid | int | None) -> Grid: @@ -213,9 +213,9 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid: def _convert_block_spec_to_block_mapping( - in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None, + in_avals: Sequence[jax_core.ShapedArray], block_spec: BlockSpec, aval: jax_core.ShapedArray, in_tree: Any, - ) -> BlockSpec | None: +) -> BlockMapping | None: if block_spec is no_block_spec: return None if block_spec.index_map is None: @@ -283,12 +283,8 @@ class GridSpec: def __init__( self, grid: Grid | None = None, - in_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, - out_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, ): # Be more lenient for in/out_specs if isinstance(in_specs, list): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 76ad43d59..a312c9587 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ from collections.abc import Sequence import dataclasses import enum import functools -from typing import Any, Union +from typing import Any from jax._src import core as jax_core from jax._src import dtypes @@ -28,15 +28,13 @@ from jax._src import util import jax.numpy as jnp from jax._src.pallas import core as pallas_core -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip partial = functools.partial Grid = pallas_core.Grid BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec AbstractMemoryRef = pallas_core.AbstractMemoryRef @@ -97,6 +95,7 @@ class SemaphoreType(enum.Enum): BARRIER = "barrier" def __call__(self, shape: tuple[int, ...]): + dtype: Any if self == SemaphoreType.DMA: dtype = DmaSemaphoreTy() elif self == SemaphoreType.BARRIER: @@ -143,9 +142,6 @@ def _make_aval(obj: object) -> jax_core.AbstractValue: "Only VMEM and SemaphoreType are supported.") -BlockSpecTree = Union[BlockSpec, NoBlockSpec, Sequence["BlockSpecTree"]] - - @dataclasses.dataclass(init=False, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): grid: Grid diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 7746a719b..a4436cd1d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -55,10 +55,11 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip Grid = pallas_core.Grid -BlockSpec = pallas_core.BlockSpec GridSpec = pallas_core.GridSpec BlockMapping = pallas_core.BlockMapping GridMapping = pallas_core.GridMapping +BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec @@ -763,14 +764,13 @@ def pallas_call( grid_spec: GridSpec | None = None, debug: bool = False, grid: Grid | None = None, - in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec, - out_specs: BlockSpec | NoBlockSpec - | Sequence[BlockSpec | NoBlockSpec] = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, input_output_aliases: dict[int, int] = {}, interpret: bool = False, name: str | None = None, compiler_params: dict[str, Any] | None = None, -): +) -> Callable[..., Any]: name = _extract_function_name(f, name) if compiler_params is None: compiler_params = {} diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 876b0b7d2..5888d5ae4 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -39,10 +39,6 @@ from jax._src.state import primitives as sp from jax.interpreters import mlir import jax.numpy as jnp - -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - partial = functools.partial Slice = indexing.Slice NDIndexer = indexing.NDIndexer @@ -64,6 +60,7 @@ program_id_p.def_custom_bind(program_id_bind) def _program_id_impl(*, axis: int): grid_env = pallas_core.current_grid_env() + assert grid_env return grid_env[axis].axis_index program_id_p.def_impl(_program_id_impl) @@ -87,6 +84,7 @@ def _num_programs_bind(*, axis: int): @num_programs_p.def_impl def _num_programs_impl(*, axis: int): grid_env = pallas_core.current_grid_env() + assert grid_env return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32) @num_programs_p.def_abstract_eval @@ -569,7 +567,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike): """ # fmt: skip has_placeholders = False if fmt: - _, field_name, *_ = next(string.Formatter().parse(fmt)) + _, field_name, *_ = next(iter(string.Formatter().parse(fmt))) has_placeholders = field_name is not None return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0905176e4..f3b09c964 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -1099,7 +1099,7 @@ def _flash_attention_bwd_dkv( grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore + in_specs=in_specs, out_specs=out_specs, scratch_shapes=scratch_shapes, ), @@ -1444,8 +1444,8 @@ def _flash_attention_bwd_dq( grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore - out_specs=out_specs, # type: ignore + in_specs=in_specs, + out_specs=out_specs, scratch_shapes=scratch_shapes, ), out_shape=out_shapes,