From 90271017375110371779a165358c217beac127d2 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Mon, 26 Aug 2024 16:51:05 -0700
Subject: [PATCH] Update usages of mosaic compiler params with TPUCompilerParams.

PiperOrigin-RevId: 667762992
---
 jax/_src/pallas/mosaic/core.py                |  4 +-
 jax/experimental/pallas/ops/tpu/all_gather.py |  2 +-
 .../pallas/ops/tpu/flash_attention.py         | 12 ++----
 jax/experimental/pallas/ops/tpu/matmul.py     |  5 +--
 .../pallas/ops/tpu/megablox/gmm.py            | 16 +++-----
 .../paged_attention/paged_attention_kernel.py |  3 +-
 .../splash_attention_kernel.py                | 37 +++++++------------
 7 files changed, 29 insertions(+), 50 deletions(-)

diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 94b53a406..e549ee05e 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -68,8 +68,8 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: list[str] | None = None
-  allow_input_fusion: list[bool] | None = None
+  dimension_semantics: Sequence[str] | None = None
+  allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
   flags: dict[str, Any] | None = None
diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py
index e121db894..8fb975504 100644
--- a/jax/experimental/pallas/ops/tpu/all_gather.py
+++ b/jax/experimental/pallas/ops/tpu/all_gather.py
@@ -136,7 +136,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
   out = pl.pallas_call(
       functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
       out_shape=out_shape,
-      compiler_params=dict(mosaic=dict(collective_id=0)),
+      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=0,
           scratch_shapes=(
diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py
index f0332a87b..6ce3a1886 100644
--- a/jax/experimental/pallas/ops/tpu/flash_attention.py
+++ b/jax/experimental/pallas/ops/tpu/flash_attention.py
@@ -745,15 +745,13 @@ def _flash_attention_impl(
       ),
       out_shape=out_shape,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids)
   if save_residuals:
@@ -1105,15 +1103,13 @@ def _flash_attention_bwd_dkv(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
   assert dk.shape == k.shape
@@ -1450,15 +1446,13 @@ def _flash_attention_bwd_dq(
       ),
       out_shape=out_shapes,
       debug=debug,
-      compiler_params=dict(
-          mosaic=dict(
+      compiler_params=pltpu.TPUCompilerParams(
           dimension_semantics=(
               "parallel",
               "parallel",
               "parallel",
               "arbitrary",
           )
-          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
 
diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py
index 2145fbc95..4ff82acbb 100644
--- a/jax/experimental/pallas/ops/tpu/matmul.py
+++ b/jax/experimental/pallas/ops/tpu/matmul.py
@@ -78,8 +78,7 @@ def matmul(
           grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
           scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
       ),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary"))
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "parallel", "arbitrary")),
       debug=debug,
   )(x, y)
diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py
index 320851422..5c2f93859 100644
--- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py
+++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py
@@ -538,11 +538,8 @@ def gmm(
           scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
       cost_estimate=cost_estimate,
   )
@@ -780,13 +777,10 @@ def tgmm(
           scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
       ),
       input_output_aliases=input_output_aliases,
-      compiler_params=dict(
-          mosaic=dict(
-              dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-              cost_estimate=cost_estimate,
-          )
-      ),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
+      cost_estimate=cost_estimate,
   )
 
   out = call_gmm(
diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index 82fa5f742..cd811a874 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -640,7 +640,8 @@ def paged_attention(
           grid=grid,
           scratch_shapes=scratch_shapes,
       ),
-      compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=dimension_sematics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
index 4ae761d78..536c32e57 100644
--- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
@@ -1071,11 +1071,6 @@ def _splash_attention_forward(
     out_shapes += [None]
     out_specs += [None]
 
-  mosaic_params = dict(
-      dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dataclasses.asdict(block_sizes),
       is_mqa=is_mqa,
@@ -1112,7 +1107,9 @@ def _splash_attention_forward(
           out_specs=out_specs,
           grid=grid,
       ),
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("parallel", "arbitrary", "arbitrary"),
+      ),
       out_shape=out_shapes,
       name=kernel_name,
       interpret=interpret,
@@ -1545,11 +1542,6 @@ def _splash_attention_bwd_dq(
   )
   num_scalar_prefetch = 3
 
-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dict(
           block_q_dq=bq,
@@ -1573,7 +1565,9 @@ def _splash_attention_bwd_dq(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(
@@ -2088,16 +2082,6 @@ def _splash_attention_bwd_dkv(
   )
   num_scalar_prefetch = 3
 
-  # We set all dimensions to arbitrary because:
-  # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
-  # megacore
-  # 2) for heads, we are reducing over heads
-  # 3) for q_seq_len, we are reducing over it to compute dkv
-  mosaic_params = dict(
-      dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
-      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
-  )
-
   kernel_name = get_kernel_name(
       dict(
           block_q_dkv=bq,
@@ -2122,7 +2106,14 @@ def _splash_attention_bwd_dkv(
           grid=grid,
       ),
       out_shape=out_shapes,
-      compiler_params=dict(mosaic=mosaic_params),
+      # We set all dimensions to arbitrary because:
+      # 1) for kv_seq_len, the splash attention prefetch schedule assumes no
+      # megacore
+      # 2) for heads, we are reducing over heads
+      # 3) for q_seq_len, we are reducing over it to compute dkv
+      compiler_params=pltpu.TPUCompilerParams(
+          dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
+      ),
       name=kernel_name,
       interpret=interpret,
   )(
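
Usage note (reviewer sketch, not part of the patch): every hunk above makes the same
substitution, replacing the untyped compiler_params=dict(mosaic=dict(...)) argument
with the typed pltpu.TPUCompilerParams dataclass. Below is a minimal, self-contained
example of the new style; the copy kernel, shapes, and 1-D grid are invented for
illustration and assume a TPU backend (pass interpret=True to pl.pallas_call to try
it without one).

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def copy_kernel(x_ref, o_ref):
      # Copy the current block of the input into the output.
      o_ref[...] = x_ref[...]

    @jax.jit
    def copy(x):
      return pl.pallas_call(
          copy_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          grid=(1,),
          # New style: a typed per-backend params object instead of a nested dict;
          # dimension_semantics has one entry per grid dimension.
          compiler_params=pltpu.TPUCompilerParams(
              dimension_semantics=("arbitrary",),
          ),
      )(x)

    print(copy(jnp.ones((8, 128), jnp.float32)))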