[Pallas] Pass compiler params via an explicit compiler_params argument instead of via **kwargs

This change makes the API a bit more intuitive and avoids footguns such as accidentally passing `in_spec` instead of `in_specs`: previously, kwargs that were not consumed by any downstream lowering were silently ignored, and users got confusing errors as a result.

This change does not deprecate the old way of passing in compiler params, but that path will be deprecated soon afterwards. (A usage sketch of the new argument follows the change summary below.)

PiperOrigin-RevId: 613239439
Authored by Sharad Vikram on 2024-03-06 09:15:36 -08:00; committed by jax authors
parent d0e0ca1e52
commit 30973a9474
12 changed files with 169 additions and 83 deletions
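
For orientation, the snippet below is a minimal sketch of the new calling convention on the TPU (Mosaic) backend, modeled on the pipelining example updated in this commit; the kernel, block shapes, and function names are illustrative, not part of the change.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_matrices_kernel(x_ref, y_ref, o_ref):
  # Element-wise add of the current input blocks.
  o_ref[...] = x_ref[...] + y_ref[...]


# Illustrative block spec: split 512 rows into two row-blocks of 256.
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      # New style: backend options live under compiler_params, keyed by
      # backend name ("mosaic" for TPU, "triton" for GPU).
      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))),
      # Old style (still accepted for now, slated for deprecation):
      #   mosaic_params=dict(dimension_semantics=("parallel",)),
  )(x, y)


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
```

Triton-backend options follow the same pattern, e.g. `compiler_params=dict(triton=dict(num_warps=4, num_stages=3))`, as in the GPU kernels updated below.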


@ -619,7 +619,8 @@
" in_specs=[block_spec, block_spec],\n",
" out_specs=block_spec,\n",
" grid=(2,),\n",
" mosaic_params=dict(dimension_semantics=(\"parallel\",)))(x, y)\n",
" compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n",
" x, y)\n",
"\n",
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
"add_matrices_pipelined_megacore(x, y)"


@ -373,7 +373,8 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
mosaic_params=dict(dimension_semantics=("parallel",)))(x, y)
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)


@ -41,8 +41,7 @@ def pallas_call_tpu_lowering_rule(
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
debug: bool,
interpret: bool,
mosaic_params: dict[str, Any] | None = None,
**compiler_params: Any):
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
@ -51,9 +50,17 @@ def pallas_call_tpu_lowering_rule(
which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping, **compiler_params)
grid_mapping=grid_mapping,
compiler_params=compiler_params)
if debug:
print(jaxpr)
if 'mosaic_params' in compiler_params:
assert 'mosaic' not in compiler_params
mosaic_params = compiler_params['mosaic_params']
elif 'mosaic' in compiler_params:
mosaic_params = compiler_params['mosaic']
else:
mosaic_params = {}
mesh = None
axis_context = ctx.module_context.axis_context
if axis_context is not None:


@ -104,7 +104,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
in_shapes,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any):
compiler_params: Any):
dynamic_grid_args, args = split_list( # type: ignore
args, [grid_mapping.num_dynamic_grid_bounds]
)
@ -234,7 +234,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
grid_mapping=grid_mapping, interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
**compiler_params)
compiler_params=compiler_params)
pallas_call_p.def_impl(_pallas_call_impl)
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
@ -243,7 +243,7 @@ pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
@ -285,7 +285,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
interpret=interpret,
debug=debug,
input_output_aliases=(),
**compiler_params,
compiler_params=compiler_params,
)
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
return out_primals, out_tangents
@ -336,7 +336,7 @@ def _pallas_call_batching_rule(args, dims, *,
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
**compiler_params: Any):
compiler_params: Any):
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
@ -449,7 +449,7 @@ def _pallas_call_batching_rule(args, dims, *,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
**compiler_params,
compiler_params=compiler_params,
)
return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
@ -535,9 +535,15 @@ def pallas_call(
input_output_aliases: dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
**compiler_params: Any,
compiler_params: dict[str, Any] | None = None,
**compiler_params_: Any,
):
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}
assert not (compiler_params and compiler_params_)
if compiler_params_:
compiler_params = compiler_params_
if grid is not None and grid_spec is not None:
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
if grid_spec is None:
@ -568,7 +574,7 @@ def pallas_call(
interpret=interpret,
grid_mapping=grid_mapping,
input_output_aliases=tuple(input_output_aliases.items()),
**compiler_params)
compiler_params=compiler_params)
out = tree_util.tree_unflatten(out_tree, out_flat)
return out
return wrapped


@ -261,8 +261,7 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
triton_params: dict[str, Any] | None = None,
**compiler_params: Any,
compiler_params: dict[str, Any],
):
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
@ -277,24 +276,22 @@ def pallas_call_lowering(
debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
**compiler_params,
compiler_params=compiler_params,
)
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
num_warps = compiler_params.pop("num_warps", 4)
triton_params = compiler_params.get("triton_params", {})
triton_compiler_params = compiler_params.get("triton", {})
num_warps = triton_compiler_params.pop("num_warps", 4)
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for Pallas kernels")
if ctx.module_context.platforms[0] == "rocm":
num_stages = compiler_params.pop("num_stages", 1)
num_stages = triton_compiler_params.pop("num_stages", 1)
else:
num_stages = compiler_params.pop("num_stages", 3)
if triton_params is None:
triton_params = {}
num_stages = triton_compiler_params.pop("num_stages", 3)
if debug:
print(jaxpr)
@ -318,7 +315,6 @@ def pallas_call_lowering(
triton_params=triton_params,
num_warps=num_warps,
num_stages=num_stages,
**compiler_params,
)


@ -220,8 +220,9 @@ def mha(
out_specs=pl.BlockSpec(
lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
),
num_warps=num_warps_,
num_stages=num_stages,
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
),
out_shape=out_shape,
debug=debug,
interpret=interpret,
@ -294,8 +295,9 @@ def _mha_forward(
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
num_warps=num_warps_,
num_stages=num_stages,
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
),
out_shape=out_shape,
debug=debug,
interpret=interpret,
@ -342,8 +344,9 @@ def _preprocess_backward(out, do, l, block_q: int,
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
num_warps=4,
num_stages=3,
compiler_params=dict(
triton=dict(num_warps=4, num_stages=3)
),
out_shape=out_shape,
debug=debug,
interpret=interpret,
@ -536,8 +539,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1,
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
input_output_aliases=input_output_aliases,
)(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq)
else:


@ -93,9 +93,15 @@ def layer_norm_forward(
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret, name='ln_forward')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
name="ln_forward",
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, mean, rstd = method(x, weight, bias)
@ -208,9 +214,15 @@ def layer_norm_backward(
kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape_dx, debug=False,
interpret=interpret, name='ln_backward_dx')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=(),
out_shape=out_shape_dx,
debug=False,
interpret=interpret,
name="ln_backward_dx",
)
method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
@ -234,9 +246,15 @@ def layer_norm_backward(
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=grid_, out_shape=out_shape_dwbias, debug=False,
interpret=interpret, name='ln_backward_dw_db')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
interpret=interpret,
name="ln_backward_dw_db",
)
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
return dx, dw, dbias
@ -264,9 +282,16 @@ def layer_norm(
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret)
method = pl.pallas_call(
kernel,
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
layer_norm.defvjp(layer_norm_forward, layer_norm_backward)


@ -81,9 +81,15 @@ def rms_norm_forward(
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret, name='rms_forward')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
name="rms_forward",
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, rstd = method(x, weight, bias)
@ -189,9 +195,15 @@ def rms_norm_backward(
kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=(), out_shape=out_shape_dx, debug=False,
interpret=interpret, name='ln_backward_dx')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=(),
out_shape=out_shape_dx,
debug=False,
interpret=interpret,
name="ln_backward_dx",
)
method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
@ -215,9 +227,15 @@ def rms_norm_backward(
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(kernel, num_warps=num_warps,
grid=grid_, out_shape=out_shape_dwbias, debug=False,
interpret=interpret, name='ln_backward_dw_db')
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
interpret=interpret,
name="ln_backward_dw_db",
)
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
return dx, dw, dbias
@ -245,9 +263,16 @@ def rms_norm(
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=(), out_shape=out_shape, debug=False,
interpret=interpret)
method = pl.pallas_call(
kernel,
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
rms_norm.defvjp(rms_norm_forward, rms_norm_backward)


@ -77,8 +77,14 @@ def softmax(
out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(),
out_shape=out_shape, debug=debug, interpret=interpret)
f = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
grid=(),
out_shape=out_shape,
debug=debug,
interpret=interpret,
)
for _ in range(len(x.shape) - 1):
f = jax.vmap(f)


@ -745,8 +745,15 @@ def _flash_attention_impl(
),
out_shape=out_shape,
debug=debug,
mosaic_params=dict(
dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
compiler_params=dict(
mosaic=dict(
dimension_semantics=(
"parallel",
"parallel",
"parallel",
"arbitrary",
)
)
),
)(q, k, v, ab, q_segment_ids, kv_segment_ids)
if save_residuals:
@ -1098,12 +1105,14 @@ def _flash_attention_bwd_dkv(
),
out_shape=out_shapes,
debug=debug,
mosaic_params=dict(
dimension_semantics=(
"parallel",
"parallel",
"parallel",
"arbitrary",
compiler_params=dict(
mosaic=dict(
dimension_semantics=(
"parallel",
"parallel",
"parallel",
"arbitrary",
)
)
),
)(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
@ -1441,12 +1450,14 @@ def _flash_attention_bwd_dq(
),
out_shape=out_shapes,
debug=debug,
mosaic_params=dict(
dimension_semantics=(
"parallel",
"parallel",
"parallel",
"arbitrary",
compiler_params=dict(
mosaic=dict(
dimension_semantics=(
"parallel",
"parallel",
"parallel",
"arbitrary",
)
)
),
)(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)


@ -1106,7 +1106,7 @@ def _splash_attention_forward(
out_specs=out_specs,
grid=grid,
),
mosaic_params=mosaic_params,
compiler_params=dict(mosaic=mosaic_params),
out_shape=out_shapes,
name=kernel_name,
)(
@ -1558,7 +1558,7 @@ def _splash_attention_bwd_dq(
grid=grid,
),
out_shape=out_shapes,
mosaic_params=mosaic_params,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
)(
mask_info.data_next,
@ -2111,7 +2111,7 @@ def _splash_attention_bwd_dkv(
grid=grid,
),
out_shape=out_shapes,
mosaic_params=mosaic_params,
compiler_params=dict(mosaic=mosaic_params),
name=kernel_name,
)(
mask_info.data_next,


@ -1252,7 +1252,7 @@ class PallasCallRemoteDMATest(parameterized.TestCase):
in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
out_shape=x,
mosaic_params=dict(collective_id=0)
compiler_params=dict(mosaic=dict(collective_id=0)),
)(x)
device_mesh = mesh_utils.create_device_mesh(
@ -1281,9 +1281,11 @@ class PallasCallTest(PallasTPUTest):
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
mosaic_params=dict(
cost_estimate=pltpu.CostEstimate(
flops=1234, transcendentals=21, bytes_accessed=12345
compiler_params=dict(
mosaic=dict(
cost_estimate=pltpu.CostEstimate(
flops=1234, transcendentals=21, bytes_accessed=12345
)
)
),
)
@ -1301,10 +1303,14 @@ class PallasCallTest(PallasTPUTest):
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaises(xla_extension.XlaRuntimeError):
pl.pallas_call(
kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=256)
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
)(x)
pl.pallas_call(
kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
)(x)
@ -1946,8 +1952,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
)
],
),
mosaic_params=dict(
collective_id=0, vmem_limit_bytes=int(134217728 * 0.9)
compiler_params=dict(
mosaic=dict(collective_id=0, vmem_limit_bytes=int(134217728 * 0.9))
),
)