mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
pallas_call now has only one way to pass compiler_params=
Previously, it was possible to do pallas_call(..., foo=42) and also pallas_call(..., compiler_params=dict(foo=42)).

PiperOrigin-RevId: 623277572
This commit is contained in:
parent
008f87d7a3
commit
a205c9120a
@ -147,8 +147,10 @@ grid axes over cores. This is an opt-in procedure. To allow that,
|
||||
..
|
||||
pallas_call(
|
||||
...,
|
||||
mosaic_params=dict(
|
||||
dimension_semantics=["parallel", "parallel", "arbitrary"]
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
dimension_semantics=["parallel", "parallel", "arbitrary"]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -561,14 +561,10 @@ def pallas_call(
|
||||
interpret: bool = False,
|
||||
name: str | None = None,
|
||||
compiler_params: dict[str, Any] | None = None,
|
||||
**compiler_params_: Any,
|
||||
):
|
||||
name = _extract_function_name(f, name)
|
||||
if compiler_params is None:
|
||||
compiler_params = {}
|
||||
assert not (compiler_params and compiler_params_)
|
||||
if compiler_params_:
|
||||
compiler_params = compiler_params_
|
||||
if grid is not None and grid_spec is not None:
|
||||
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
|
||||
if grid_spec is None:
|
||||
|
@ -165,8 +165,9 @@ def attn_unbatched(
|
||||
),
|
||||
), # m
|
||||
],
|
||||
num_warps=num_warps_,
|
||||
num_stages=num_stages,
|
||||
compiler_params=dict(
|
||||
triton=dict(num_warps=num_warps_, num_stages=num_stages)
|
||||
),
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
|
||||
jax.ShapeDtypeStruct(
|
||||
|
@ -136,7 +136,7 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
|
||||
out = pl.pallas_call(
|
||||
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
|
||||
out_shape=out_shape,
|
||||
mosaic_params=dict(collective_id=0),
|
||||
compiler_params=dict(mosaic=dict(collective_id=0)),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
scratch_shapes=(
|
||||
|
@ -537,9 +537,11 @@ def gmm(
|
||||
scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
|
||||
),
|
||||
input_output_aliases=input_output_aliases,
|
||||
mosaic_params=dict(
|
||||
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
|
||||
cost_estimate=cost_estimate,
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
),
|
||||
interpret=interpret,
|
||||
)
|
||||
@ -777,9 +779,11 @@ def tgmm(
|
||||
scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
|
||||
),
|
||||
input_output_aliases=input_output_aliases,
|
||||
mosaic_params=dict(
|
||||
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
|
||||
cost_estimate=cost_estimate,
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
),
|
||||
interpret=interpret,
|
||||
)
|
||||
|
@ -487,9 +487,7 @@ def paged_attention(
|
||||
pltpu.SemaphoreType.DMA,
|
||||
),
|
||||
),
|
||||
mosaic_params=dict(
|
||||
dimension_semantics=dimension_sematics,
|
||||
),
|
||||
compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
|
||||
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user