Added a new approx_math flag to Mosaic GPU params in Pallas

The flag allows to control the precision of some operations, e.g. `exp`.

PiperOrigin-RevId: 674305430
This commit is contained in:
Sergei Lebedev 2024-09-13 08:19:57 -07:00 committed by jax authors
parent 8fa0e925dd
commit 40040e3f69
3 changed files with 26 additions and 8 deletions

View File

@ -29,11 +29,13 @@ import jax.numpy as jnp
AbstractMemoryRef = pallas_core.AbstractMemoryRef
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class GPUCompilerParams(pallas_core.CompilerParams):
"""Mosaic GPU compiler parameters.
Attributes:
approx_math: If True, the compiler is allowed to use approximate
implementations of some math operations, e.g. ``exp``. Defaults to False.
dimension_semantics: A list of dimension semantics for each grid
dimension of the kernel. Either "parallel" for dimensions that can
execute in any order, or "sequential" for dimensions that must be
@ -42,6 +44,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
meaning no pipelining is done.
"""
PLATFORM: ClassVar[str] = "mosaic_gpu"
approx_math: bool = False
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
num_stages: int = 1

View File

@ -97,8 +97,9 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
class ModuleContext:
name: str
grid_mapping: pallas_core.GridMapping
approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int
smem_used_bytes: int = 0
# TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
def scratch_view(
@ -233,11 +234,12 @@ def lower_jaxpr_to_module(
grid += (1,) * (3 - len(grid))
block = (128,) + (1,) * (len(grid) - 1)
params = compiler_params.get("mosaic_gpu", {})
approx_math = params.get("approx_math", False)
num_stages = params.get("num_stages", 1)
dimension_semantics = params.get(
"dimension_semantics", ["parallel"] * len(grid_mapping.grid)
)
if len(dimension_semantics) != len(grid_mapping.grid):
dimension_semantics = params.get("dimension_semantics")
if dimension_semantics is None:
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
elif len(dimension_semantics) != len(grid_mapping.grid):
raise ValueError(
"dimension_semantics must have an entrey for each grid dimension:"
f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})."
@ -295,7 +297,7 @@ def lower_jaxpr_to_module(
)
module_ctx = ModuleContext(
name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0
name_and_src_info.name, grid_mapping, approx_math, runtime_smem
)
program_ids = map(_program_id, range(len(grid_mapping.grid)))
start_indices = map(
@ -622,7 +624,7 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
@register_lowering_rule(lax.rsqrt_p)
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
return _ensure_fa(x, *ctx.avals_in).rsqrt()
return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_context.approx_math)
@register_lowering_rule(lax.reduce_sum_p)

View File

@ -127,6 +127,19 @@ class PallasCallTest(PallasTest):
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + x.sum()*2)
@parameterized.parameters(False, True)
def test_rsqrt(self, approx_math):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
)
def kernel(x_ref, o_ref):
o_ref[...] = jax.lax.rsqrt(x_ref[...])
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
def test_layer_norm(self, input_factor):
eps = 1e-5