Update usages of mosaic compiler params with TPUCompilerParams.
PiperOrigin-RevId: 667762992
This commit is contained in:
parent 57c0d59d04
commit 9027101737
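The change replaces the untyped dict(mosaic=dict(...)) form of Pallas compiler parameters with the typed pltpu.TPUCompilerParams dataclass. A minimal before/after sketch of the call-site pattern, using a placeholder copy kernel and shapes that are not from this commit (intended to run on a TPU backend):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
  # Placeholder kernel body: copy the input block to the output block.
  o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.float32)

# Old style (removed by this commit): nested dicts keyed by platform name.
# out = pl.pallas_call(
#     copy_kernel,
#     out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
#     grid=(1,),
#     compiler_params=dict(mosaic=dict(dimension_semantics=("arbitrary",))),
# )(x)

# New style: a typed dataclass with named fields.
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    compiler_params=pltpu.TPUCompilerParams(
        dimension_semantics=("arbitrary",)),
)(x)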
@@ -68,8 +68,8 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: list[str] | None = None
-  allow_input_fusion: list[bool] | None = None
+  dimension_semantics: Sequence[str] | None = None
+  allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
   flags: dict[str, Any] | None = None
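Since dimension_semantics and allow_input_fusion are now annotated as Sequence rather than list, the tuples used at every call site below match the declared field types. An illustrative construction using the fields visible in this hunk; the values are placeholders, not taken from the commit:

from jax.experimental.pallas import tpu as pltpu

# Placeholder values, not taken from the commit.
params = pltpu.TPUCompilerParams(
    dimension_semantics=("parallel", "arbitrary"),  # a tuple satisfies Sequence[str]
    allow_input_fusion=(True, False),               # likewise Sequence[bool]
    vmem_limit_bytes=64 * 1024 * 1024,
    collective_id=0,
)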
@@ -136,7 +136,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
   out = pl.pallas_call(
       functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
       out_shape=out_shape,
-      compiler_params=dict(mosaic=dict(collective_id=0)),
+      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=0,
           scratch_shapes=(
@@ -745,15 +745,13 @@ def _flash_attention_impl(
       ),
       out_shape=out_shape,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=(
-                  "parallel",
-                  "parallel",
-                  "parallel",
-                  "arbitrary",
-              )
-          )
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=(
+              "parallel",
+              "parallel",
+              "parallel",
+              "arbitrary",
+          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids)
   if save_residuals:
@@ -1105,15 +1103,13 @@ def _flash_attention_bwd_dkv(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=(
-                  "parallel",
-                  "parallel",
-                  "parallel",
-                  "arbitrary",
-              )
-          )
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=(
+              "parallel",
+              "parallel",
+              "parallel",
+              "arbitrary",
+          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
   assert dk.shape == k.shape
@@ -1450,15 +1446,13 @@ def _flash_attention_bwd_dq(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=(
-                  "parallel",
-                  "parallel",
-                  "parallel",
-                  "arbitrary",
-              )
-          )
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=(
+              "parallel",
+              "parallel",
+              "parallel",
+              "arbitrary",
+          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
@@ -78,8 +78,7 @@ def matmul(
           grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
           scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
       ),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary"))
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
       debug=debug,
   )(x, y)
@@ -538,11 +538,8 @@ def gmm(
           scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
       cost_estimate=cost_estimate,
   )
@@ -780,13 +777,10 @@ def tgmm(
           scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-              cost_estimate=cost_estimate,
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
+      cost_estimate=cost_estimate,
   )

   out = call_gmm(
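A detail of the tgmm hunk: cost_estimate previously sat inside the mosaic dict, where it was not a compiler parameter, and after the change it is passed only as a pallas_call keyword, matching the gmm call above. A reduced sketch of that call shape with a placeholder kernel; cost_estimate is left at its None default because the commit only shows the keyword being threaded through:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def scale_kernel(x_ref, o_ref):
  # Placeholder kernel body: double the input block.
  o_ref[...] = x_ref[...] * 2.0

x = jnp.ones((128, 128), jnp.float32)

out = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    compiler_params=pltpu.TPUCompilerParams(
        dimension_semantics=("arbitrary",)),
    # The cost estimate is a top-level pallas_call argument, not a compiler
    # parameter; None (the default) stands in for a real estimate here.
    cost_estimate=None,
)(x)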
@@ -640,7 +640,8 @@ def paged_attention(
           grid=grid,
           scratch_shapes=scratch_shapes,
       ),
-      compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=dimension_sematics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
@@ -1071,11 +1071,6 @@ def _splash_attention_forward(
     out_shapes += [None]
     out_specs += [None]

-  mosaic_params = dict(
-      dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dataclasses.asdict(block_sizes),
       is_mqa=is_mqa,
@@ -1112,7 +1107,9 @@ def _splash_attention_forward(
           out_specs=out_specs,
           grid=grid,
       ),
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary"),
+      ),
       out_shape=out_shapes,
       name=kernel_name,
       interpret=interpret,
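The old mosaic_params in the splash attention kernels also set flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, which the new TPUCompilerParams call sites do not carry over. The dataclass still exposes a flags field (see the first hunk), so the same flag could be expressed in the new style if it were still wanted; a hypothetical sketch, not part of this commit:

from jax.experimental.pallas import tpu as pltpu

# Hypothetical: keeping the dropped XLA flag alongside the new-style params.
# Whether the flag is still wanted is not stated by this commit.
params = pltpu.TPUCompilerParams(
    dimension_semantics=("parallel", "arbitrary", "arbitrary"),
    flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
)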
@@ -1545,11 +1542,6 @@ def _splash_attention_bwd_dq(
   )
   num_scalar_prefetch = 3

-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dict(
           block_q_dq=bq,
@@ -1573,7 +1565,9 @@ def _splash_attention_bwd_dq(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(
@@ -2088,16 +2082,6 @@ def _splash_attention_bwd_dkv(
   )
   num_scalar_prefetch = 3

-  # We set all dimensions to arbitrary because:
-  # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
-  #    megacore
-  # 2) for heads, we are reducing over heads
-  # 3) for q_seq_len, we are reducing over it to compute dkv
-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dict(
           block_q_dkv=bq,
@@ -2122,7 +2106,14 @@ def _splash_attention_bwd_dkv(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      # We set all dimensions to arbitrary because:
+      # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
+      #    megacore
+      # 2) for heads, we are reducing over heads
+      # 3) for q_seq_len, we are reducing over it to compute dkv
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(