From 30973a9474c44d06fa2edba8e200938dba739d64 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 6 Mar 2024 09:15:36 -0800 Subject: [PATCH] [Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result. This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this. PiperOrigin-RevId: 613239439 --- docs/pallas/tpu/pipelining.ipynb | 3 +- docs/pallas/tpu/pipelining.md | 3 +- .../pallas/mosaic/pallas_call_registration.py | 13 +++-- jax/_src/pallas/pallas_call.py | 22 ++++++--- .../pallas/triton/pallas_call_registration.py | 18 +++---- jax/experimental/pallas/ops/attention.py | 18 ++++--- jax/experimental/pallas/ops/layer_norm.py | 49 ++++++++++++++----- jax/experimental/pallas/ops/rms_norm.py | 49 ++++++++++++++----- jax/experimental/pallas/ops/softmax.py | 10 +++- .../pallas/ops/tpu/flash_attention.py | 39 +++++++++------ .../splash_attention_kernel.py | 6 +-- tests/pallas/pallas_call_tpu_test.py | 22 ++++++--- 12 files changed, 169 insertions(+), 83 deletions(-) diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 4a927b85f..fe9271e92 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -619,7 +619,8 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " mosaic_params=dict(dimension_semantics=(\"parallel\",)))(x, y)\n", + " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",))))(\n", + " x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", "add_matrices_pipelined_megacore(x, y)" diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 65b86554f..8e42364c2 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -373,7 +373,8 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - mosaic_params=dict(dimension_semantics=("parallel",)))(x, y) + compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))( + x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) add_matrices_pipelined_megacore(x, y) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 96d4f32b0..aeb2a733e 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -41,8 +41,7 @@ def pallas_call_tpu_lowering_rule( out_shapes: tuple[jax.ShapeDtypeStruct, ...], debug: bool, interpret: bool, - mosaic_params: dict[str, Any] | None = None, - **compiler_params: Any): + compiler_params: dict[str, Any]): """Lowers a pallas_call to a Mosaic TPU custom call.""" if interpret: return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( @@ -51,9 +50,17 @@ def pallas_call_tpu_lowering_rule( which_linear=which_linear, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, - grid_mapping=grid_mapping, **compiler_params) + grid_mapping=grid_mapping, + compiler_params=compiler_params) if debug: print(jaxpr) + if 'mosaic_params' in compiler_params: + assert 'mosaic' not in compiler_params + mosaic_params = 
compiler_params['mosaic_params'] + elif 'mosaic' in compiler_params: + mosaic_params = compiler_params['mosaic'] + else: + mosaic_params = {} mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 7dac9e579..ef6067841 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -104,7 +104,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, in_shapes, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, - **compiler_params: Any): + compiler_params: Any): dynamic_grid_args, args = split_list( # type: ignore args, [grid_mapping.num_dynamic_grid_bounds] ) @@ -234,7 +234,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, grid_mapping=grid_mapping, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, - **compiler_params) + compiler_params=compiler_params) pallas_call_p.def_impl(_pallas_call_impl) def _pallas_call_abstract_eval(*avals, out_shapes, **_): @@ -243,7 +243,7 @@ pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, input_output_aliases: tuple[tuple[int, int], ...], - in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any): + in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any): if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") if grid_mapping.num_index_operands: @@ -285,7 +285,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, interpret=interpret, debug=debug, input_output_aliases=(), - **compiler_params, + compiler_params=compiler_params, ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents @@ -336,7 +336,7 @@ def _pallas_call_batching_rule(args, dims, *, debug: bool, interpret: bool, which_linear: tuple[bool, ...], - **compiler_params: Any): + compiler_params: Any): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -449,7 +449,7 @@ def _pallas_call_batching_rule(args, dims, *, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, - **compiler_params, + compiler_params=compiler_params, ) return out, (0,) * len(out) batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule @@ -535,9 +535,15 @@ def pallas_call( input_output_aliases: dict[int, int] = {}, interpret: bool = False, name: str | None = None, - **compiler_params: Any, + compiler_params: dict[str, Any] | None = None, + **compiler_params_: Any, ): name = _extract_function_name(f, name) + if compiler_params is None: + compiler_params = {} + assert not (compiler_params and compiler_params_) + if compiler_params_: + compiler_params = compiler_params_ if grid is not None and grid_spec is not None: raise ValueError("Cannot specify both grid and grid_spec at the same time.") if grid_spec is None: @@ -568,7 +574,7 @@ def pallas_call( interpret=interpret, grid_mapping=grid_mapping, input_output_aliases=tuple(input_output_aliases.items()), - **compiler_params) + compiler_params=compiler_params) out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 8735ceff4..aa7b90b1b 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py 
+++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -261,8 +261,7 @@ def pallas_call_lowering( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, - triton_params: dict[str, Any] | None = None, - **compiler_params: Any, + compiler_params: dict[str, Any], ): if interpret: return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( @@ -277,24 +276,22 @@ def pallas_call_lowering( debug=debug, input_output_aliases=input_output_aliases, grid_mapping=grid_mapping, - **compiler_params, + compiler_params=compiler_params, ) if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" ) - - num_warps = compiler_params.pop("num_warps", 4) + triton_params = compiler_params.get("triton_params", {}) + triton_compiler_params = compiler_params.get("triton", {}) + num_warps = triton_compiler_params.pop("num_warps", 4) if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for Pallas kernels") if ctx.module_context.platforms[0] == "rocm": - num_stages = compiler_params.pop("num_stages", 1) + num_stages = triton_compiler_params.pop("num_stages", 1) else: - num_stages = compiler_params.pop("num_stages", 3) - - if triton_params is None: - triton_params = {} + num_stages = triton_compiler_params.pop("num_stages", 3) if debug: print(jaxpr) @@ -318,7 +315,6 @@ def pallas_call_lowering( triton_params=triton_params, num_warps=num_warps, num_stages=num_stages, - **compiler_params, ) diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index f6b5a773a..a96d3e1cc 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -220,8 +220,9 @@ def mha( out_specs=pl.BlockSpec( lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) ), - num_warps=num_warps_, - num_stages=num_stages, + compiler_params=dict( + triton=dict(num_warps=num_warps_, num_stages=num_stages) + ), out_shape=out_shape, debug=debug, interpret=interpret, @@ -294,8 +295,9 @@ def _mha_forward( pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], - num_warps=num_warps_, - num_stages=num_stages, + compiler_params=dict( + triton=dict(num_warps=num_warps_, num_stages=num_stages) + ), out_shape=out_shape, debug=debug, interpret=interpret, @@ -342,8 +344,9 @@ def _preprocess_backward(out, do, l, block_q: int, pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], - num_warps=4, - num_stages=3, + compiler_params=dict( + triton=dict(num_warps=4, num_stages=3) + ), out_shape=out_shape, debug=debug, interpret=interpret, @@ -536,8 +539,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, name="mha_backward", debug=debug, interpret=interpret, - num_warps=num_warps, - num_stages=1, + compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), input_output_aliases=input_output_aliases, )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq) else: diff --git a/jax/experimental/pallas/ops/layer_norm.py b/jax/experimental/pallas/ops/layer_norm.py index 31c6fe43c..269f29dc7 100644 --- a/jax/experimental/pallas/ops/layer_norm.py +++ b/jax/experimental/pallas/ops/layer_norm.py @@ -93,9 +93,15 @@ def layer_norm_forward( jax.ShapeDtypeStruct(shape=(), dtype=x.dtype), jax.ShapeDtypeStruct(shape=(), dtype=x.dtype) ] - 
method = pl.pallas_call(kernel, num_warps=num_warps, - grid=(), out_shape=out_shape, debug=False, - interpret=interpret, name='ln_forward') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=(), + out_shape=out_shape, + debug=False, + interpret=interpret, + name="ln_forward", + ) method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) out, mean, rstd = method(x, weight, bias) @@ -208,9 +214,15 @@ def layer_norm_backward( kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps, block_size=block_size) out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) - method = pl.pallas_call(kernel, num_warps=num_warps, - grid=(), out_shape=out_shape_dx, debug=False, - interpret=interpret, name='ln_backward_dx') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=(), + out_shape=out_shape_dx, + debug=False, + interpret=interpret, + name="ln_backward_dx", + ) method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0)) dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd) @@ -234,9 +246,15 @@ def layer_norm_backward( jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype) ] grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) - method = pl.pallas_call(kernel, num_warps=num_warps, - grid=grid_, out_shape=out_shape_dwbias, debug=False, - interpret=interpret, name='ln_backward_dw_db') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=grid_, + out_shape=out_shape_dwbias, + debug=False, + interpret=interpret, + name="ln_backward_dw_db", + ) dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd) return dx, dw, dbias @@ -264,9 +282,16 @@ def layer_norm( kernel = functools.partial(layer_norm_forward_kernel, eps=eps, block_size=block_size) out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) - method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages, - grid=(), out_shape=out_shape, debug=False, - interpret=interpret) + method = pl.pallas_call( + kernel, + compiler_params=dict( + triton=dict(num_warps=num_warps, num_stages=num_stages) + ), + grid=(), + out_shape=out_shape, + debug=False, + interpret=interpret, + ) method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) return method(x, weight, bias) layer_norm.defvjp(layer_norm_forward, layer_norm_backward) diff --git a/jax/experimental/pallas/ops/rms_norm.py b/jax/experimental/pallas/ops/rms_norm.py index f6a5bc6f7..e1dfa3c5b 100644 --- a/jax/experimental/pallas/ops/rms_norm.py +++ b/jax/experimental/pallas/ops/rms_norm.py @@ -81,9 +81,15 @@ def rms_norm_forward( jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype), jax.ShapeDtypeStruct(shape=(), dtype=x.dtype) ] - method = pl.pallas_call(kernel, num_warps=num_warps, - grid=(), out_shape=out_shape, debug=False, - interpret=interpret, name='rms_forward') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=(), + out_shape=out_shape, + debug=False, + interpret=interpret, + name="rms_forward", + ) method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) out, rstd = method(x, weight, bias) @@ -189,9 +195,15 @@ def rms_norm_backward( kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps, block_size=block_size) out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) - method = pl.pallas_call(kernel, num_warps=num_warps, - 
grid=(), out_shape=out_shape_dx, debug=False, - interpret=interpret, name='ln_backward_dx') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=(), + out_shape=out_shape_dx, + debug=False, + interpret=interpret, + name="ln_backward_dx", + ) method = jax.vmap(method, in_axes=(0, None, None, 0, 0)) dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd) @@ -215,9 +227,15 @@ def rms_norm_backward( jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype) ] grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) - method = pl.pallas_call(kernel, num_warps=num_warps, - grid=grid_, out_shape=out_shape_dwbias, debug=False, - interpret=interpret, name='ln_backward_dw_db') + method = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps)), + grid=grid_, + out_shape=out_shape_dwbias, + debug=False, + interpret=interpret, + name="ln_backward_dw_db", + ) dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd) return dx, dw, dbias @@ -245,9 +263,16 @@ def rms_norm( kernel = functools.partial(rms_norm_forward_kernel, eps=eps, block_size=block_size) out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) - method = pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages, - grid=(), out_shape=out_shape, debug=False, - interpret=interpret) + method = pl.pallas_call( + kernel, + compiler_params=dict( + triton=dict(num_warps=num_warps, num_stages=num_stages) + ), + grid=(), + out_shape=out_shape, + debug=False, + interpret=interpret, + ) method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None)) return method(x, weight, bias) rms_norm.defvjp(rms_norm_forward, rms_norm_backward) diff --git a/jax/experimental/pallas/ops/softmax.py b/jax/experimental/pallas/ops/softmax.py index a53620923..3671331b8 100644 --- a/jax/experimental/pallas/ops/softmax.py +++ b/jax/experimental/pallas/ops/softmax.py @@ -77,8 +77,14 @@ def softmax( out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype) kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) - f = pl.pallas_call(kernel, num_warps=num_warps, num_stages=1, grid=(), - out_shape=out_shape, debug=debug, interpret=interpret) + f = pl.pallas_call( + kernel, + compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), + grid=(), + out_shape=out_shape, + debug=debug, + interpret=interpret, + ) for _ in range(len(x.shape) - 1): f = jax.vmap(f) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index de6cf2361..c63929bb6 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -745,8 +745,15 @@ def _flash_attention_impl( ), out_shape=out_shape, debug=debug, - mosaic_params=dict( - dimension_semantics=("parallel", "parallel", "parallel", "arbitrary") + compiler_params=dict( + mosaic=dict( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) + ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids) if save_residuals: @@ -1098,12 +1105,14 @@ def _flash_attention_bwd_dkv( ), out_shape=out_shapes, debug=debug, - mosaic_params=dict( - dimension_semantics=( - "parallel", - "parallel", - "parallel", - "arbitrary", + compiler_params=dict( + mosaic=dict( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) @@ -1441,12 +1450,14 @@ def _flash_attention_bwd_dq( ), 
out_shape=out_shapes, debug=debug, - mosaic_params=dict( - dimension_semantics=( - "parallel", - "parallel", - "parallel", - "arbitrary", + compiler_params=dict( + mosaic=dict( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 602c07558..313b2a677 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -1106,7 +1106,7 @@ def _splash_attention_forward( out_specs=out_specs, grid=grid, ), - mosaic_params=mosaic_params, + compiler_params=dict(mosaic=mosaic_params), out_shape=out_shapes, name=kernel_name, )( @@ -1558,7 +1558,7 @@ def _splash_attention_bwd_dq( grid=grid, ), out_shape=out_shapes, - mosaic_params=mosaic_params, + compiler_params=dict(mosaic=mosaic_params), name=kernel_name, )( mask_info.data_next, @@ -2111,7 +2111,7 @@ def _splash_attention_bwd_dkv( grid=grid, ), out_shape=out_shapes, - mosaic_params=mosaic_params, + compiler_params=dict(mosaic=mosaic_params), name=kernel_name, )( mask_info.data_next, diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 2e74ae45e..c27e6468b 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -1252,7 +1252,7 @@ class PallasCallRemoteDMATest(parameterized.TestCase): in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=x, - mosaic_params=dict(collective_id=0) + compiler_params=dict(mosaic=dict(collective_id=0)), )(x) device_mesh = mesh_utils.create_device_mesh( @@ -1281,9 +1281,11 @@ class PallasCallTest(PallasTPUTest): f = pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - mosaic_params=dict( - cost_estimate=pltpu.CostEstimate( - flops=1234, transcendentals=21, bytes_accessed=12345 + compiler_params=dict( + mosaic=dict( + cost_estimate=pltpu.CostEstimate( + flops=1234, transcendentals=21, bytes_accessed=12345 + ) ) ), ) @@ -1301,10 +1303,14 @@ class PallasCallTest(PallasTPUTest): x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) with self.assertRaises(xla_extension.XlaRuntimeError): pl.pallas_call( - kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=256) + kernel, + out_shape=x, + compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)), )(x) pl.pallas_call( - kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18)) + kernel, + out_shape=x, + compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))), )(x) @@ -1946,8 +1952,8 @@ class PallasCallPipelineTest(parameterized.TestCase): ) ], ), - mosaic_params=dict( - collective_id=0, vmem_limit_bytes=int(134217728 * 0.9) + compiler_params=dict( + mosaic=dict(collective_id=0, vmem_limit_bytes=int(134217728 * 0.9)) ), )
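
Example usage of the new API (a minimal sketch adapted from the docs/pallas/tpu/pipelining example updated in this patch; the kernel, shapes, and parameter values below are illustrative and not part of the change, and the `pl.BlockSpec(index_map, block_shape)` form follows the signature the docs use at this revision):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_matrices_kernel(x_ref, y_ref, o_ref):
  # Element-wise add of the current input blocks.
  o_ref[...] = x_ref[...] + y_ref[...]


x = jnp.ones((512, 512), jnp.float32)
y = jnp.ones((512, 512), jnp.float32)

# Backend-specific options are now grouped under a single `compiler_params`
# dict keyed by backend name ("mosaic" for TPU, "triton" for GPU) instead of
# being passed as loose keyword arguments to `pallas_call`. The old kwargs
# (e.g. `mosaic_params=...`, `num_warps=...`) still work for now but are
# slated for deprecation.
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
add_tpu = pl.pallas_call(
    add_matrices_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    in_specs=[block_spec, block_spec],
    out_specs=block_spec,
    grid=(2,),
    compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))),
)

# Equivalent GPU form: Triton options are nested under the "triton" key.
add_gpu = pl.pallas_call(
    add_matrices_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
)

# On a TPU: out = add_tpu(x, y); on a GPU: out = add_gpu(x, y).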