mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Updated the type annotations of *_spec= parameters of pl.pallas_call
The previous type did not work for nested pytrees and for some reason neither pytype nor mypy flagged that. I also re-enabled type checking for most pallas/*.py files.
This commit is contained in:
parent
f8473509cf
commit
70f6ab3128
@ -15,13 +15,12 @@
|
||||
"""Module for pallas-core functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
import copy
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Callable, Union
|
||||
from collections.abc import Iterator
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core as jax_core
|
||||
@ -33,9 +32,6 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.state import discharge as state_discharge
|
||||
import jax.numpy as jnp
|
||||
|
||||
# TODO(sharadmv): enable type checking
|
||||
# mypy: ignore-errors
|
||||
|
||||
partial = functools.partial
|
||||
Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic.
|
||||
DynamicGrid = tuple[Union[int, jax_core.Array], ...]
|
||||
@ -156,6 +152,10 @@ class BlockSpec:
|
||||
return out
|
||||
|
||||
|
||||
# A PyTree of BlockSpec | NoBlockSpec.
|
||||
BlockSpecTree = Any
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BlockMapping:
|
||||
block_shape: tuple[Mapped | int, ...]
|
||||
@ -201,7 +201,7 @@ class GridMapping:
|
||||
def static_grid(self) -> StaticGrid:
|
||||
if self.num_dynamic_grid_bounds:
|
||||
raise ValueError("Expected a grid with fully static bounds")
|
||||
return self.grid # typing: ignore
|
||||
return self.grid # type: ignore
|
||||
|
||||
|
||||
def _preprocess_grid(grid: Grid | int | None) -> Grid:
|
||||
@ -213,9 +213,9 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
|
||||
|
||||
|
||||
def _convert_block_spec_to_block_mapping(
|
||||
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
|
||||
in_avals: Sequence[jax_core.ShapedArray], block_spec: BlockSpec,
|
||||
aval: jax_core.ShapedArray, in_tree: Any,
|
||||
) -> BlockSpec | None:
|
||||
) -> BlockMapping | None:
|
||||
if block_spec is no_block_spec:
|
||||
return None
|
||||
if block_spec.index_map is None:
|
||||
@ -283,12 +283,8 @@ class GridSpec:
|
||||
def __init__(
|
||||
self,
|
||||
grid: Grid | None = None,
|
||||
in_specs: BlockSpec
|
||||
| Sequence[BlockSpec | NoBlockSpec]
|
||||
| NoBlockSpec = no_block_spec,
|
||||
out_specs: BlockSpec
|
||||
| Sequence[BlockSpec | NoBlockSpec]
|
||||
| NoBlockSpec = no_block_spec,
|
||||
in_specs: BlockSpecTree = no_block_spec,
|
||||
out_specs: BlockSpecTree = no_block_spec,
|
||||
):
|
||||
# Be more lenient for in/out_specs
|
||||
if isinstance(in_specs, list):
|
||||
|
@ -19,7 +19,7 @@ from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import dtypes
|
||||
@ -28,15 +28,13 @@ from jax._src import util
|
||||
import jax.numpy as jnp
|
||||
from jax._src.pallas import core as pallas_core
|
||||
|
||||
# TODO(sharadmv): enable type checking
|
||||
# mypy: ignore-errors
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
partial = functools.partial
|
||||
Grid = pallas_core.Grid
|
||||
BlockSpec = pallas_core.BlockSpec
|
||||
BlockSpecTree = pallas_core.BlockSpecTree
|
||||
GridMapping = pallas_core.GridMapping
|
||||
NoBlockSpec = pallas_core.NoBlockSpec
|
||||
AbstractMemoryRef = pallas_core.AbstractMemoryRef
|
||||
@ -97,6 +95,7 @@ class SemaphoreType(enum.Enum):
|
||||
BARRIER = "barrier"
|
||||
|
||||
def __call__(self, shape: tuple[int, ...]):
|
||||
dtype: Any
|
||||
if self == SemaphoreType.DMA:
|
||||
dtype = DmaSemaphoreTy()
|
||||
elif self == SemaphoreType.BARRIER:
|
||||
@ -143,9 +142,6 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
|
||||
"Only VMEM and SemaphoreType are supported.")
|
||||
|
||||
|
||||
BlockSpecTree = Union[BlockSpec, NoBlockSpec, Sequence["BlockSpecTree"]]
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
grid: Grid
|
||||
|
@ -55,10 +55,11 @@ map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Grid = pallas_core.Grid
|
||||
BlockSpec = pallas_core.BlockSpec
|
||||
GridSpec = pallas_core.GridSpec
|
||||
BlockMapping = pallas_core.BlockMapping
|
||||
GridMapping = pallas_core.GridMapping
|
||||
BlockSpec = pallas_core.BlockSpec
|
||||
BlockSpecTree = pallas_core.BlockSpecTree
|
||||
NoBlockSpec = pallas_core.NoBlockSpec
|
||||
no_block_spec = pallas_core.no_block_spec
|
||||
|
||||
@ -763,14 +764,13 @@ def pallas_call(
|
||||
grid_spec: GridSpec | None = None,
|
||||
debug: bool = False,
|
||||
grid: Grid | None = None,
|
||||
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
|
||||
out_specs: BlockSpec | NoBlockSpec
|
||||
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
|
||||
in_specs: BlockSpecTree = no_block_spec,
|
||||
out_specs: BlockSpecTree = no_block_spec,
|
||||
input_output_aliases: dict[int, int] = {},
|
||||
interpret: bool = False,
|
||||
name: str | None = None,
|
||||
compiler_params: dict[str, Any] | None = None,
|
||||
):
|
||||
) -> Callable[..., Any]:
|
||||
name = _extract_function_name(f, name)
|
||||
if compiler_params is None:
|
||||
compiler_params = {}
|
||||
|
@ -39,10 +39,6 @@ from jax._src.state import primitives as sp
|
||||
from jax.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# TODO(sharadmv): enable type checking
|
||||
# mypy: ignore-errors
|
||||
|
||||
partial = functools.partial
|
||||
Slice = indexing.Slice
|
||||
NDIndexer = indexing.NDIndexer
|
||||
@ -64,6 +60,7 @@ program_id_p.def_custom_bind(program_id_bind)
|
||||
|
||||
def _program_id_impl(*, axis: int):
|
||||
grid_env = pallas_core.current_grid_env()
|
||||
assert grid_env
|
||||
return grid_env[axis].axis_index
|
||||
program_id_p.def_impl(_program_id_impl)
|
||||
|
||||
@ -87,6 +84,7 @@ def _num_programs_bind(*, axis: int):
|
||||
@num_programs_p.def_impl
|
||||
def _num_programs_impl(*, axis: int):
|
||||
grid_env = pallas_core.current_grid_env()
|
||||
assert grid_env
|
||||
return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32)
|
||||
|
||||
@num_programs_p.def_abstract_eval
|
||||
@ -569,7 +567,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
|
||||
""" # fmt: skip
|
||||
has_placeholders = False
|
||||
if fmt:
|
||||
_, field_name, *_ = next(string.Formatter().parse(fmt))
|
||||
_, field_name, *_ = next(iter(string.Formatter().parse(fmt)))
|
||||
has_placeholders = field_name is not None
|
||||
return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders)
|
||||
|
||||
|
@ -1099,7 +1099,7 @@ def _flash_attention_bwd_dkv(
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
grid=grid,
|
||||
in_specs=in_specs, # type: ignore
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
@ -1444,8 +1444,8 @@ def _flash_attention_bwd_dq(
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
grid=grid,
|
||||
in_specs=in_specs, # type: ignore
|
||||
out_specs=out_specs, # type: ignore
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
out_shape=out_shapes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user