mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas:mosaic_gpu] Do not store the grid mapping in ModuleContext
We really only ever use the grid names. PiperOrigin-RevId: 703108864
This commit is contained in:
parent
d034680f6d
commit
e5102957b0
@ -17,13 +17,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import MutableMapping, MutableSequence, Sequence
|
||||
from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools as it
|
||||
import math
|
||||
from typing import Any, Hashable, Protocol, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -192,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
name: str
|
||||
grid_mapping: pallas_core.GridMapping
|
||||
grid_names: Sequence[Hashable] | None
|
||||
program_ids: Sequence[ir.Value] | None
|
||||
approx_math: bool
|
||||
runtime_smem: ir.Value # ir.MemRefType
|
||||
@ -517,7 +517,7 @@ def lower_jaxpr_to_module(
|
||||
grouped_barriers[barrier].append(barrier_ref)
|
||||
module_ctx = ModuleContext(
|
||||
name_and_src_info.name,
|
||||
grid_mapping,
|
||||
grid_mapping.grid_names,
|
||||
None,
|
||||
approx_math,
|
||||
runtime_smem,
|
||||
@ -1290,7 +1290,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
@register_lowering_rule(lax.axis_index_p)
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
grid_names = ctx.module_ctx.grid_mapping.grid_names
|
||||
grid_names = ctx.module_ctx.grid_names
|
||||
squashed_dims = ctx.module_ctx.squashed_dims
|
||||
if squashed_dims:
|
||||
unsquashed_names = grid_names[-3:]
|
||||
|
Loading…
x
Reference in New Issue
Block a user