Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs
This change makes the API a bit more intuitive and avoids footguns such as accidentally passing `in_spec` instead of `in_specs`: previously, kwargs that weren't used by any downstream lowering were silently ignored, and users got confusing errors as a result. This change doesn't deprecate the old way of passing in compiler params, but that path will be deprecated soon after this. PiperOrigin-RevId: 613239439
This commit is contained in: parent d0e0ca1e52, commit 30973a9474
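To make the new calling convention concrete, here is a minimal sketch modeled on the pipelining tutorial updated in the diff below. It is illustrative only: the kernel body and `block_spec` are assumed from earlier in that tutorial rather than taken from this diff, and it needs a TPU backend to run since it passes Mosaic options.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_matrices_kernel(x_ref, y_ref, o_ref):
  # Elementwise add of the current (256, 512) input blocks.
  o_ref[...] = x_ref[...] + y_ref[...]

# Split each (512, 512) operand into two (256, 512) blocks along the grid axis.
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))

add_matrices_pipelined_megacore = pl.pallas_call(
    add_matrices_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    in_specs=[block_spec, block_spec],
    out_specs=block_spec,
    grid=(2,),
    # Old style (still accepted for now, slated for deprecation): loose kwargs,
    # e.g. mosaic_params=dict(dimension_semantics=("parallel",)).
    # New style: one explicit dict keyed by backend name ("mosaic" for TPU,
    # "triton" for GPU), so a mistyped keyword now fails loudly.
    compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))),
)
out = add_matrices_pipelined_megacore(x, y)

Keying the dict by backend name also lets a single call site carry options for several backends without colliding with `pallas_call`'s own keyword arguments.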
@@ -619,7 +619,8 @@
 "      in_specs=[block_spec, block_spec],\n",
 "      out_specs=block_spec,\n",
 "      grid=(2,),\n",
-"      mosaic_params=dict(dimension_semantics=(\"parallel\",)))(x, y)\n",
+"      compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n",
+"          x, y)\n",
 "\n",
 "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
 "add_matrices_pipelined_megacore(x, y)"
@@ -373,7 +373,8 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
       in_specs=[block_spec, block_spec],
       out_specs=block_spec,
       grid=(2,),
-      mosaic_params=dict(dimension_semantics=("parallel",)))(x, y)
+      compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
+          x, y)
 
 x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
 add_matrices_pipelined_megacore(x, y)
@@ -41,8 +41,7 @@ def pallas_call_tpu_lowering_rule(
     out_shapes: tuple[jax.ShapeDtypeStruct, ...],
     debug: bool,
     interpret: bool,
-    mosaic_params: dict[str, Any] | None = None,
-    **compiler_params: Any):
+    compiler_params: dict[str, Any]):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   if interpret:
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
@@ -51,9 +50,17 @@ def pallas_call_tpu_lowering_rule(
         which_linear=which_linear,
         interpret=interpret, debug=debug,
         input_output_aliases=input_output_aliases,
-        grid_mapping=grid_mapping, **compiler_params)
+        grid_mapping=grid_mapping,
+        compiler_params=compiler_params)
   if debug:
     print(jaxpr)
+  if 'mosaic_params' in compiler_params:
+    assert 'mosaic' not in compiler_params
+    mosaic_params = compiler_params['mosaic_params']
+  elif 'mosaic' in compiler_params:
+    mosaic_params = compiler_params['mosaic']
+  else:
+    mosaic_params = {}
   mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
@@ -104,7 +104,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
     in_shapes,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
-    **compiler_params: Any):
+    compiler_params: Any):
   dynamic_grid_args, args = split_list(  # type: ignore
       args, [grid_mapping.num_dynamic_grid_bounds]
   )
@@ -234,7 +234,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
         grid_mapping=grid_mapping, interpret=interpret,
         debug=debug,
         input_output_aliases=input_output_aliases,
-        **compiler_params)
+        compiler_params=compiler_params)
 pallas_call_p.def_impl(_pallas_call_impl)
 
 def _pallas_call_abstract_eval(*avals, out_shapes, **_):
@@ -243,7 +243,7 @@ pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
 
 def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
     input_output_aliases: tuple[tuple[int, int], ...],
-    in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
+    in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError("interpret with dynamic grid bounds unsupported")
   if grid_mapping.num_index_operands:
@@ -285,7 +285,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
-      **compiler_params,
+      compiler_params=compiler_params,
   )
   out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
   return out_primals, out_tangents
@@ -336,7 +336,7 @@ def _pallas_call_batching_rule(args, dims, *,
     debug: bool,
     interpret: bool,
     which_linear: tuple[bool, ...],
-    **compiler_params: Any):
+    compiler_params: Any):
 
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -449,7 +449,7 @@ def _pallas_call_batching_rule(args, dims, *,
       input_output_aliases=input_output_aliases,
       debug=debug,
       interpret=interpret,
-      **compiler_params,
+      compiler_params=compiler_params,
   )
   return out, (0,) * len(out)
 batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
@@ -535,9 +535,15 @@ def pallas_call(
     input_output_aliases: dict[int, int] = {},
     interpret: bool = False,
     name: str | None = None,
-    **compiler_params: Any,
+    compiler_params: dict[str, Any] | None = None,
+    **compiler_params_: Any,
 ):
   name = _extract_function_name(f, name)
+  if compiler_params is None:
+    compiler_params = {}
+  assert not (compiler_params and compiler_params_)
+  if compiler_params_:
+    compiler_params = compiler_params_
   if grid is not None and grid_spec is not None:
     raise ValueError("Cannot specify both grid and grid_spec at the same time.")
   if grid_spec is None:
@@ -568,7 +574,7 @@ def pallas_call(
         interpret=interpret,
         grid_mapping=grid_mapping,
         input_output_aliases=tuple(input_output_aliases.items()),
-        **compiler_params)
+        compiler_params=compiler_params)
     out = tree_util.tree_unflatten(out_tree, out_flat)
     return out
   return wrapped
@@ -261,8 +261,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
-    triton_params: dict[str, Any] | None = None,
-    **compiler_params: Any,
+    compiler_params: dict[str, Any],
 ):
   if interpret:
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
@@ -277,24 +276,22 @@ def pallas_call_lowering(
         debug=debug,
         input_output_aliases=input_output_aliases,
         grid_mapping=grid_mapping,
-        **compiler_params,
+        compiler_params=compiler_params,
     )
 
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError(
         "dynamic grid bounds not supported in the Triton backend"
     )
 
-  num_warps = compiler_params.pop("num_warps", 4)
+  triton_params = compiler_params.get("triton_params", {})
+  triton_compiler_params = compiler_params.get("triton", {})
+  num_warps = triton_compiler_params.pop("num_warps", 4)
   if len(ctx.module_context.platforms) > 1:
     raise NotImplementedError("multi-platform lowering for Pallas kernels")
   if ctx.module_context.platforms[0] == "rocm":
-    num_stages = compiler_params.pop("num_stages", 1)
+    num_stages = triton_compiler_params.pop("num_stages", 1)
   else:
-    num_stages = compiler_params.pop("num_stages", 3)
-
-  if triton_params is None:
-    triton_params = {}
-
+    num_stages = triton_compiler_params.pop("num_stages", 3)
   if debug:
     print(jaxpr)
@@ -318,7 +315,6 @@ def pallas_call_lowering(
       triton_params=triton_params,
      num_warps=num_warps,
       num_stages=num_stages,
-      **compiler_params,
   )
 
 
@@ -220,8 +220,9 @@ def mha(
       out_specs=pl.BlockSpec(
           lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
       ),
-      num_warps=num_warps_,
-      num_stages=num_stages,
+      compiler_params=dict(
+          triton=dict(num_warps=num_warps_, num_stages=num_stages)
+      ),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
@@ -294,8 +295,9 @@ def _mha_forward(
           pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
           pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
       ],
-      num_warps=num_warps_,
-      num_stages=num_stages,
+      compiler_params=dict(
+          triton=dict(num_warps=num_warps_, num_stages=num_stages)
+      ),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
@@ -342,8 +344,9 @@ def _preprocess_backward(out, do, l, block_q: int,
           pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
           pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
       ],
-      num_warps=4,
-      num_stages=3,
+      compiler_params=dict(
+          triton=dict(num_warps=4, num_stages=3)
+      ),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
@@ -536,8 +539,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
         name="mha_backward",
         debug=debug,
         interpret=interpret,
-        num_warps=num_warps,
-        num_stages=1,
+        compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
         input_output_aliases=input_output_aliases,
     )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq)
   else:
@@ -93,9 +93,15 @@ def layer_norm_forward(
       jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
       jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
   ]
-  method = pl.pallas_call(kernel, num_warps=num_warps,
-                          grid=(), out_shape=out_shape, debug=False,
-                          interpret=interpret, name='ln_forward')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=(),
+      out_shape=out_shape,
+      debug=False,
+      interpret=interpret,
+      name="ln_forward",
+  )
 
   method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
   out, mean, rstd = method(x, weight, bias)
@@ -208,9 +214,15 @@ def layer_norm_backward(
   kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
                              block_size=block_size)
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
-  method = pl.pallas_call(kernel, num_warps=num_warps,
-                          grid=(), out_shape=out_shape_dx, debug=False,
-                          interpret=interpret, name='ln_backward_dx')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=(),
+      out_shape=out_shape_dx,
+      debug=False,
+      interpret=interpret,
+      name="ln_backward_dx",
+  )
 
   method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
   dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
@@ -234,9 +246,15 @@ def layer_norm_backward(
       jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
   ]
   grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
-  method = pl.pallas_call(kernel, num_warps=num_warps,
-                          grid=grid_, out_shape=out_shape_dwbias, debug=False,
-                          interpret=interpret, name='ln_backward_dw_db')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=grid_,
+      out_shape=out_shape_dwbias,
+      debug=False,
+      interpret=interpret,
+      name="ln_backward_dw_db",
+  )
   dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
   return dx, dw, dbias
 
@@ -264,9 +282,16 @@ def layer_norm(
   kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
                              block_size=block_size)
   out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
-  method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
-                          grid=(), out_shape=out_shape, debug=False,
-                          interpret=interpret)
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(
+          triton=dict(num_warps=num_warps, num_stages=num_stages)
+      ),
+      grid=(),
+      out_shape=out_shape,
+      debug=False,
+      interpret=interpret,
+  )
   method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
   return method(x, weight, bias)
 layer_norm.defvjp(layer_norm_forward, layer_norm_backward)
@@ -81,9 +81,15 @@ def rms_norm_forward(
       jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
       jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
   ]
-  method = pl.pallas_call(kernel, num_warps=num_warps,
-                          grid=(), out_shape=out_shape, debug=False,
-                          interpret=interpret, name='rms_forward')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=(),
+      out_shape=out_shape,
+      debug=False,
+      interpret=interpret,
+      name="rms_forward",
+  )
 
   method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
   out, rstd = method(x, weight, bias)
@@ -189,9 +195,15 @@ def rms_norm_backward(
   kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
                              block_size=block_size)
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
-  method = pl.pallas_call(kernel, num_warps=num_warps,
                          grid=(), out_shape=out_shape_dx, debug=False,
-                          interpret=interpret, name='ln_backward_dx')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=(),
+      out_shape=out_shape_dx,
+      debug=False,
+      interpret=interpret,
+      name="ln_backward_dx",
+  )
 
   method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
   dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
@@ -215,9 +227,15 @@ def rms_norm_backward(
       jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
   ]
   grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
-  method = pl.pallas_call(kernel, num_warps=num_warps,
-                          grid=grid_, out_shape=out_shape_dwbias, debug=False,
-                          interpret=interpret, name='ln_backward_dw_db')
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      grid=grid_,
+      out_shape=out_shape_dwbias,
+      debug=False,
+      interpret=interpret,
+      name="ln_backward_dw_db",
+  )
   dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
   return dx, dw, dbias
 
@@ -245,9 +263,16 @@ def rms_norm(
   kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
                              block_size=block_size)
   out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
-  method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
-                          grid=(), out_shape=out_shape, debug=False,
-                          interpret=interpret)
+  method = pl.pallas_call(
+      kernel,
+      compiler_params=dict(
+          triton=dict(num_warps=num_warps, num_stages=num_stages)
+      ),
+      grid=(),
+      out_shape=out_shape,
+      debug=False,
+      interpret=interpret,
+  )
   method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
   return method(x, weight, bias)
 rms_norm.defvjp(rms_norm_forward, rms_norm_backward)
@@ -77,8 +77,14 @@ def softmax(
   out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)
 
   kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
-  f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(),
-                     out_shape=out_shape, debug=debug, interpret=interpret)
+  f = pl.pallas_call(
+      kernel,
+      compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
+      grid=(),
+      out_shape=out_shape,
+      debug=debug,
+      interpret=interpret,
+  )
 
   for _ in range(len(x.shape) - 1):
     f = jax.vmap(f)
@@ -745,8 +745,15 @@ def _flash_attention_impl(
       ),
       out_shape=out_shape,
       debug=debug,
-      mosaic_params=dict(
-          dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
+      compiler_params=dict(
+          mosaic=dict(
+              dimension_semantics=(
+                  "parallel",
+                  "parallel",
+                  "parallel",
+                  "arbitrary",
+              )
+          )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids)
   if save_residuals:
@@ -1098,12 +1105,14 @@ def _flash_attention_bwd_dkv(
       ),
       out_shape=out_shapes,
       debug=debug,
-      mosaic_params=dict(
-          dimension_semantics=(
-              "parallel",
-              "parallel",
-              "parallel",
-              "arbitrary",
+      compiler_params=dict(
+          mosaic=dict(
+              dimension_semantics=(
+                  "parallel",
+                  "parallel",
+                  "parallel",
+                  "arbitrary",
+              )
           )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
@@ -1441,12 +1450,14 @@ def _flash_attention_bwd_dq(
       ),
       out_shape=out_shapes,
       debug=debug,
-      mosaic_params=dict(
-          dimension_semantics=(
-              "parallel",
-              "parallel",
-              "parallel",
-              "arbitrary",
+      compiler_params=dict(
+          mosaic=dict(
+              dimension_semantics=(
+                  "parallel",
+                  "parallel",
+                  "parallel",
+                  "arbitrary",
+              )
           )
       ),
   )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di)
@@ -1106,7 +1106,7 @@ def _splash_attention_forward(
           out_specs=out_specs,
           grid=grid,
       ),
-      mosaic_params=mosaic_params,
+      compiler_params=dict(mosaic=mosaic_params),
       out_shape=out_shapes,
       name=kernel_name,
   )(
@@ -1558,7 +1558,7 @@ def _splash_attention_bwd_dq(
           grid=grid,
       ),
       out_shape=out_shapes,
-      mosaic_params=mosaic_params,
+      compiler_params=dict(mosaic=mosaic_params),
       name=kernel_name,
   )(
       mask_info.data_next,
@@ -2111,7 +2111,7 @@ def _splash_attention_bwd_dkv(
           grid=grid,
       ),
       out_shape=out_shapes,
-      mosaic_params=mosaic_params,
+      compiler_params=dict(mosaic=mosaic_params),
       name=kernel_name,
   )(
       mask_info.data_next,
@@ -1252,7 +1252,7 @@ class PallasCallRemoteDMATest(parameterized.TestCase):
         in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
         out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
         out_shape=x,
-        mosaic_params=dict(collective_id=0)
+        compiler_params=dict(mosaic=dict(collective_id=0)),
     )(x)
 
     device_mesh = mesh_utils.create_device_mesh(
@@ -1281,9 +1281,11 @@ class PallasCallTest(PallasTPUTest):
     f = pl.pallas_call(
         kernel,
         out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
-        mosaic_params=dict(
-            cost_estimate=pltpu.CostEstimate(
-                flops=1234, transcendentals=21, bytes_accessed=12345
+        compiler_params=dict(
+            mosaic=dict(
+                cost_estimate=pltpu.CostEstimate(
+                    flops=1234, transcendentals=21, bytes_accessed=12345
+                )
             )
         ),
     )
@@ -1301,10 +1303,14 @@ class PallasCallTest(PallasTPUTest):
     x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
     with self.assertRaises(xla_extension.XlaRuntimeError):
       pl.pallas_call(
-          kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=256)
+          kernel,
+          out_shape=x,
+          compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
       )(x)
     pl.pallas_call(
-        kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
+        kernel,
+        out_shape=x,
+        compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
     )(x)
 
 
@@ -1946,8 +1952,8 @@ class PallasCallPipelineTest(parameterized.TestCase):
             )
         ],
       ),
-      mosaic_params=dict(
-          collective_id=0, vmem_limit_bytes=int(134217728 * 0.9)
+      compiler_params=dict(
+          mosaic=dict(collective_id=0, vmem_limit_bytes=int(134217728 * 0.9))
       ),
   )
 