[pallas:mosaic_gpu] Do not store the grid mapping in ModuleContext

We really only ever use the grid names.

PiperOrigin-RevId: 703108864
This commit is contained in:
Sergei Lebedev 2024-12-05 07:32:10 -08:00 committed by jax authors
parent d034680f6d
commit e5102957b0

View File

@ -17,13 +17,13 @@
from __future__ import annotations
import collections
from collections.abc import MutableMapping, MutableSequence, Sequence
from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence
import contextlib
import dataclasses
import functools
import itertools as it
import math
from typing import Any, Hashable, Protocol, cast
from typing import Any, Protocol, cast
import jax
from jax import lax
@ -192,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int
@dataclasses.dataclass
class ModuleContext:
name: str
grid_mapping: pallas_core.GridMapping
grid_names: Sequence[Hashable] | None
program_ids: Sequence[ir.Value] | None
approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
@ -517,7 +517,7 @@ def lower_jaxpr_to_module(
grouped_barriers[barrier].append(barrier_ref)
module_ctx = ModuleContext(
name_and_src_info.name,
grid_mapping,
grid_mapping.grid_names,
None,
approx_math,
runtime_smem,
@ -1290,7 +1290,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
@register_lowering_rule(lax.axis_index_p)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
i32 = ir.IntegerType.get_signless(32)
grid_names = ctx.module_ctx.grid_mapping.grid_names
grid_names = ctx.module_ctx.grid_names
squashed_dims = ctx.module_ctx.squashed_dims
if squashed_dims:
unsquashed_names = grid_names[-3:]