pallas_call now has only one way to pass compiler_params=

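Previously, compiler parameters could be passed in two ways: directly as
extra keyword arguments

    pallas_call(..., foo=42)

and also through the compiler_params= argument

    pallas_call(..., compiler_params=dict(foo=42))

Now only the second form is accepted.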

PiperOrigin-RevId: 623277572
This commit is contained in:
Sergei Lebedev 2024-04-09 14:13:06 -07:00 committed by jax authors
parent 008f87d7a3
commit a205c9120a
6 changed files with 19 additions and 18 deletions

View File

@ -147,8 +147,10 @@ grid axes over cores. This is an opt-in procedure. To allow that,
..
pallas_call(
...,
mosaic_params=dict(
dimension_semantics=["parallel", "parallel", "arbitrary"]
compiler_params=dict(
mosaic=dict(
dimension_semantics=["parallel", "parallel", "arbitrary"]
)
),
)

View File

@ -561,14 +561,10 @@ def pallas_call(
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | None = None,
**compiler_params_: Any,
):
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}
assert not (compiler_params and compiler_params_)
if compiler_params_:
compiler_params = compiler_params_
if grid is not None and grid_spec is not None:
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
if grid_spec is None:

View File

@ -165,8 +165,9 @@ def attn_unbatched(
),
), # m
],
num_warps=num_warps_,
num_stages=num_stages,
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
),
out_shape=[
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
jax.ShapeDtypeStruct(

View File

@ -136,7 +136,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
out = pl.pallas_call(
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
out_shape=out_shape,
mosaic_params=dict(collective_id=0),
compiler_params=dict(mosaic=dict(collective_id=0)),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
scratch_shapes=(

View File

@ -537,9 +537,11 @@ def gmm(
scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
),
input_output_aliases=input_output_aliases,
mosaic_params=dict(
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
cost_estimate=cost_estimate,
compiler_params=dict(
mosaic=dict(
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
cost_estimate=cost_estimate,
)
),
interpret=interpret,
)
@ -777,9 +779,11 @@ def tgmm(
scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
),
input_output_aliases=input_output_aliases,
mosaic_params=dict(
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
cost_estimate=cost_estimate,
compiler_params=dict(
mosaic=dict(
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
cost_estimate=cost_estimate,
)
),
interpret=interpret,
)

View File

@ -487,9 +487,7 @@ def paged_attention(
pltpu.SemaphoreType.DMA,
),
),
mosaic_params=dict(
dimension_semantics=dimension_sematics,
),
compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
out_shape=[
jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),