mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Added a new `approx_math` flag to Mosaic GPU params in Pallas
The flag allows controlling the precision of some operations, e.g. `exp`.
PiperOrigin-RevId: 674305430
This commit is contained in:
parent
8fa0e925dd
commit
40040e3f69
@ -29,11 +29,13 @@ import jax.numpy as jnp
|
||||
AbstractMemoryRef = pallas_core.AbstractMemoryRef
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class GPUCompilerParams(pallas_core.CompilerParams):
|
||||
"""Mosaic GPU compiler parameters.
|
||||
|
||||
Attributes:
|
||||
approx_math: If True, the compiler is allowed to use approximate
|
||||
implementations of some math operations, e.g. ``exp``. Defaults to False.
|
||||
dimension_semantics: A list of dimension semantics for each grid
|
||||
dimension of the kernel. Either "parallel" for dimensions that can
|
||||
execute in any order, or "sequential" for dimensions that must be
|
||||
@ -42,6 +44,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
|
||||
meaning no pipelining is done.
|
||||
"""
|
||||
PLATFORM: ClassVar[str] = "mosaic_gpu"
|
||||
approx_math: bool = False
|
||||
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
|
||||
num_stages: int = 1
|
||||
|
||||
|
@ -97,8 +97,9 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
|
||||
class ModuleContext:
|
||||
name: str
|
||||
grid_mapping: pallas_core.GridMapping
|
||||
approx_math: bool
|
||||
runtime_smem: ir.Value # ir.MemRefType
|
||||
smem_used_bytes: int
|
||||
smem_used_bytes: int = 0
|
||||
|
||||
# TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
|
||||
def scratch_view(
|
||||
@ -233,11 +234,12 @@ def lower_jaxpr_to_module(
|
||||
grid += (1,) * (3 - len(grid))
|
||||
block = (128,) + (1,) * (len(grid) - 1)
|
||||
params = compiler_params.get("mosaic_gpu", {})
|
||||
approx_math = params.get("approx_math", False)
|
||||
num_stages = params.get("num_stages", 1)
|
||||
dimension_semantics = params.get(
|
||||
"dimension_semantics", ["parallel"] * len(grid_mapping.grid)
|
||||
)
|
||||
if len(dimension_semantics) != len(grid_mapping.grid):
|
||||
dimension_semantics = params.get("dimension_semantics")
|
||||
if dimension_semantics is None:
|
||||
dimension_semantics = ["parallel"] * len(grid_mapping.grid)
|
||||
elif len(dimension_semantics) != len(grid_mapping.grid):
|
||||
raise ValueError(
|
||||
"dimension_semantics must have an entrey for each grid dimension:"
|
||||
f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})."
|
||||
@ -295,7 +297,7 @@ def lower_jaxpr_to_module(
|
||||
)
|
||||
|
||||
module_ctx = ModuleContext(
|
||||
name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0
|
||||
name_and_src_info.name, grid_mapping, approx_math, runtime_smem
|
||||
)
|
||||
program_ids = map(_program_id, range(len(grid_mapping.grid)))
|
||||
start_indices = map(
|
||||
@ -622,7 +624,7 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
|
||||
@register_lowering_rule(lax.rsqrt_p)
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return _ensure_fa(x, *ctx.avals_in).rsqrt()
|
||||
return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_context.approx_math)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_sum_p)
|
||||
|
@ -127,6 +127,19 @@ class PallasCallTest(PallasTest):
|
||||
x = jnp.arange(128).astype(jnp.float32)
|
||||
np.testing.assert_array_equal(kernel(x), x + x.sum()*2)
|
||||
|
||||
@parameterized.parameters(False, True)
|
||||
def test_rsqrt(self, approx_math):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
|
||||
compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
o_ref[...] = jax.lax.rsqrt(x_ref[...])
|
||||
|
||||
x = jnp.arange(128).astype(jnp.float32)
|
||||
np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))
|
||||
|
||||
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
|
||||
def test_layer_norm(self, input_factor):
|
||||
eps = 1e-5
|
||||
|
Loading…
x
Reference in New Issue
Block a user