Update usages of mosaic compiler params with TPUCompilerParams.

PiperOrigin-RevId: 667762992
Justin Fu 2024-08-26 16:51:05 -07:00 committed by jax authors
parent 57c0d59d04
commit 9027101737
7 changed files with 29 additions and 50 deletions
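
The migration pattern is the same in every file below: the untyped dict(mosaic=dict(...)) argument to pl.pallas_call is replaced by the typed pltpu.TPUCompilerParams dataclass. A minimal before/after sketch, assuming the usual Pallas imports and a placeholder copy kernel (the kernel, shapes, and option values here are illustrative, not taken from this commit):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.ones((8, 128), jnp.float32)

    # Before: Mosaic options as an untyped nested dict keyed by the platform name.
    out_old = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(1,),
        compiler_params=dict(mosaic=dict(dimension_semantics=("arbitrary",))),
    )(x)

    # After: the same options as the typed pltpu.TPUCompilerParams dataclass.
    out_new = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(1,),
        compiler_params=pltpu.TPUCompilerParams(
            dimension_semantics=("arbitrary",),
        ),
    )(x)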

@@ -68,8 +68,8 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: list[str] | None = None
-  allow_input_fusion: list[bool] | None = None
+  dimension_semantics: Sequence[str] | None = None
+  allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
   flags: dict[str, Any] | None = None
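
With the fields typed as Sequence rather than list, callers can pass tuples directly. A small construction sketch using the fields shown above; the values are hypothetical and every field is optional (defaulting to None):

    from jax.experimental.pallas import tpu as pltpu

    # Hypothetical values, purely to illustrate the field types listed above.
    params = pltpu.TPUCompilerParams(
        dimension_semantics=("parallel", "arbitrary"),  # one entry per grid dimension
        allow_input_fusion=(True, False),               # one bool per kernel input
        vmem_limit_bytes=64 * 1024 * 1024,
        flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},  # flag name borrowed from the splash-attention hunks below
    )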

@@ -136,7 +136,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
   out = pl.pallas_call(
       functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
       out_shape=out_shape,
-      compiler_params=dict(mosaic=dict(collective_id=0)),
+      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=0,
           scratch_shapes=(

@@ -745,15 +745,13 @@ def _flash_attention_impl(
       ),
       out_shape=out_shape,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids)
   if save_residuals:
@@ -1105,15 +1103,13 @@ def _flash_attention_bwd_dkv(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
   assert dk.shape == k.shape
@@ -1450,15 +1446,13 @@ def _flash_attention_bwd_dq(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)

@@ -78,8 +78,7 @@ def matmul(
           grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
           scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
       ),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary"))
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
       debug=debug,
   )(x, y)

@@ -538,11 +538,8 @@ def gmm(
           scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
       cost_estimate=cost_estimate,
   )
@@ -780,13 +777,10 @@ def tgmm(
           scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-              cost_estimate=cost_estimate,
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
+      cost_estimate=cost_estimate,
   )
   out = call_gmm(

@@ -640,7 +640,8 @@ def paged_attention(
           grid=grid,
           scratch_shapes=scratch_shapes,
       ),
-      compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=dimension_sematics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),

@@ -1071,11 +1071,6 @@ def _splash_attention_forward(
     out_shapes += [None]
     out_specs += [None]
-  mosaic_params = dict(
-      dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
   kernel_name = get_kernel_name(
       dataclasses.asdict(block_sizes),
       is_mqa=is_mqa,
@@ -1112,7 +1107,9 @@ def _splash_attention_forward(
           out_specs=out_specs,
           grid=grid,
       ),
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary"),
+      ),
       out_shape=out_shapes,
       name=kernel_name,
       interpret=interpret,
@@ -1545,11 +1542,6 @@ def _splash_attention_bwd_dq(
   )
   num_scalar_prefetch = 3
-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
   kernel_name = get_kernel_name(
       dict(
           block_q_dq=bq,
@@ -1573,7 +1565,9 @@ def _splash_attention_bwd_dq(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(
@@ -2088,16 +2082,6 @@ def _splash_attention_bwd_dkv(
   )
   num_scalar_prefetch = 3
-  # We set all dimensions to arbitrary because:
-  # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
-  #    megacore
-  # 2) for heads, we are reducing over heads
-  # 3) for q_seq_len, we are reducing over it to compute dkv
-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
   kernel_name = get_kernel_name(
       dict(
           block_q_dkv=bq,
@@ -2122,7 +2106,14 @@ def _splash_attention_bwd_dkv(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      # We set all dimensions to arbitrary because:
+      # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
+      #    megacore
+      # 2) for heads, we are reducing over heads
+      # 3) for q_seq_len, we are reducing over it to compute dkv
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(