Updated the type annotations of *_spec= parameters of pl.pallas_call

The previous type annotation did not work for nested pytrees, and for some reason
neither pytype nor mypy flagged that.

I also re-enabled type checking for most pallas/*.py files.
This commit is contained in:
Sergei Lebedev 2024-06-07 12:07:07 +01:00
parent f8473509cf
commit 70f6ab3128
5 changed files with 24 additions and 34 deletions

View File

@ -15,13 +15,12 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
from collections.abc import Iterator, Sequence
import copy
from collections.abc import Sequence
import contextlib
import dataclasses
import functools
from typing import Any, Callable, Union
from collections.abc import Iterator
from jax._src import api_util
from jax._src import core as jax_core
@ -33,9 +32,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
import jax.numpy as jnp
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
partial = functools.partial
Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic.
DynamicGrid = tuple[Union[int, jax_core.Array], ...]
@ -156,6 +152,10 @@ class BlockSpec:
return out
# A PyTree of BlockSpec | NoBlockSpec.
BlockSpecTree = Any
@dataclasses.dataclass(frozen=True)
class BlockMapping:
block_shape: tuple[Mapped | int, ...]
@ -201,7 +201,7 @@ class GridMapping:
def static_grid(self) -> StaticGrid:
if self.num_dynamic_grid_bounds:
raise ValueError("Expected a grid with fully static bounds")
return self.grid # typing: ignore
return self.grid # type: ignore
def _preprocess_grid(grid: Grid | int | None) -> Grid:
@ -213,9 +213,9 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _convert_block_spec_to_block_mapping(
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
in_avals: Sequence[jax_core.ShapedArray], block_spec: BlockSpec,
aval: jax_core.ShapedArray, in_tree: Any,
) -> BlockSpec | None:
) -> BlockMapping | None:
if block_spec is no_block_spec:
return None
if block_spec.index_map is None:
@ -283,12 +283,8 @@ class GridSpec:
def __init__(
self,
grid: Grid | None = None,
in_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
out_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
):
# Be more lenient for in/out_specs
if isinstance(in_specs, list):

View File

@ -19,7 +19,7 @@ from collections.abc import Sequence
import dataclasses
import enum
import functools
from typing import Any, Union
from typing import Any
from jax._src import core as jax_core
from jax._src import dtypes
@ -28,15 +28,13 @@ from jax._src import util
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
GridMapping = pallas_core.GridMapping
NoBlockSpec = pallas_core.NoBlockSpec
AbstractMemoryRef = pallas_core.AbstractMemoryRef
@ -97,6 +95,7 @@ class SemaphoreType(enum.Enum):
BARRIER = "barrier"
def __call__(self, shape: tuple[int, ...]):
dtype: Any
if self == SemaphoreType.DMA:
dtype = DmaSemaphoreTy()
elif self == SemaphoreType.BARRIER:
@ -143,9 +142,6 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
"Only VMEM and SemaphoreType are supported.")
BlockSpecTree = Union[BlockSpec, NoBlockSpec, Sequence["BlockSpecTree"]]
@dataclasses.dataclass(init=False, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: Grid

View File

@ -55,10 +55,11 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
@ -763,14 +764,13 @@ def pallas_call(
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
out_specs: BlockSpec | NoBlockSpec
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
input_output_aliases: dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | None = None,
):
) -> Callable[..., Any]:
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}

View File

@ -39,10 +39,6 @@ from jax._src.state import primitives as sp
from jax.interpreters import mlir
import jax.numpy as jnp
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
partial = functools.partial
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
@ -64,6 +60,7 @@ program_id_p.def_custom_bind(program_id_bind)
def _program_id_impl(*, axis: int):
grid_env = pallas_core.current_grid_env()
assert grid_env
return grid_env[axis].axis_index
program_id_p.def_impl(_program_id_impl)
@ -87,6 +84,7 @@ def _num_programs_bind(*, axis: int):
@num_programs_p.def_impl
def _num_programs_impl(*, axis: int):
grid_env = pallas_core.current_grid_env()
assert grid_env
return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32)
@num_programs_p.def_abstract_eval
@ -569,7 +567,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
""" # fmt: skip
has_placeholders = False
if fmt:
_, field_name, *_ = next(string.Formatter().parse(fmt))
_, field_name, *_ = next(iter(string.Formatter().parse(fmt)))
has_placeholders = field_name is not None
return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)

View File

@ -1099,7 +1099,7 @@ def _flash_attention_bwd_dkv(
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
grid=grid,
in_specs=in_specs, # type: ignore
in_specs=in_specs,
out_specs=out_specs,
scratch_shapes=scratch_shapes,
),
@ -1444,8 +1444,8 @@ def _flash_attention_bwd_dq(
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
grid=grid,
in_specs=in_specs, # type: ignore
out_specs=out_specs, # type: ignore
in_specs=in_specs,
out_specs=out_specs,
scratch_shapes=scratch_shapes,
),
out_shape=out_shapes,