Move CostEstimate from pltpu to pl

* Move CostEstimate from the TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`.
  Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in three months' time.
* Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis.
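
For reference, a minimal sketch of the new calling convention (illustrative numbers, mirroring the test added at the end of this diff; the deprecated spelling is shown in a comment for comparison):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Trivial kernel: copy the input block to the output block.
  o_ref[...] = x_ref[...]

f = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # New: a platform-independent estimate, passed directly to pallas_call.
    cost_estimate=pl.CostEstimate(
        flops=1234, transcendentals=21, bytes_accessed=12345
    ),
)
# Deprecated (TPU-only) spelling, still accepted with a DeprecationWarning:
#   compiler_params=dict(mosaic=dict(cost_estimate=pltpu.CostEstimate(...)))

y = f(jnp.zeros((8, 128), jnp.float32))
# Under jax.vmap, the batching rule below scales the estimate by the batch size.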

PiperOrigin-RevId: 659560330
Author: jax authors, 2024-08-05 08:17:18 -07:00 (committed by jax authors)
Parent: ecf9f64240
Commit: 9762ac53c8
10 changed files with 157 additions and 45 deletions

View File

@@ -826,3 +826,16 @@ class PallasMesh(mesh_lib.Mesh):
   @property
   def _is_jax_device_mesh(self):
     return False


+@dataclasses.dataclass(frozen=True)
+class CostEstimate:
+  flops: int
+  transcendentals: int
+  bytes_accessed: int
+
+  def to_json(self) -> bytes:
+    return (
+        f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
+        f' "bytes_accessed": {self.bytes_accessed}}}'
+    ).encode("ascii")
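
A quick sketch of the serialized form, derived directly from the definition above:

est = CostEstimate(flops=10, transcendentals=2, bytes_accessed=4096)
est.to_json()  # => b'{"flops": 10, "transcendentals": 2, "bytes_accessed": 4096}'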

View File

@@ -22,8 +22,8 @@ from typing import Any
 import warnings

 import jax
-from jax import dtypes
 from jax import core as jax_core
+from jax import dtypes
 from jax._src import config
 from jax._src import core as jax_src_core
 from jax._src import sharding_impls
@@ -34,6 +34,7 @@ from jax._src.pallas.mosaic import lowering
 from jax._src.pallas.mosaic import verification
 from jax.experimental import mosaic
 from jax.experimental.mosaic.dialects import tpu
+from jax.experimental.pallas import tpu as pltpu

 def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
   """Casts boolean values to integers.
@@ -71,7 +72,9 @@ def pallas_call_tpu_lowering_rule(
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
-    compiler_params: dict[str, Any]):
+    compiler_params: dict[str, Any],
+    cost_estimate: core.CostEstimate | None,
+):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   del interpret
   if debug:
@@ -90,6 +93,22 @@ def pallas_call_tpu_lowering_rule(
     mosaic_params = compiler_params["mosaic"]
   else:
     mosaic_params = {}

+  if "cost_estimate" in mosaic_params:
+    # TODO(amagni): Remove this branch after October 22th 2024.
+    if cost_estimate is not None:
+      raise ValueError(
+          "Passing cost estimate via both compiler_params=dict(mosaic=...) and"
+          " pallas_call(..., cost_estimate=...) is not supported."
+      )
+    warnings.warn(
+        "Passing cost estimate via compiler_params=dict(cost_estimate=...) is"
+        " deprecated. Use pallas_call(..., cost_estimate=...) instead.",
+        DeprecationWarning,
+    )
+    cost_estimate = mosaic_params["cost_estimate"]
+
 mesh = None
   mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
@@ -172,27 +191,33 @@ def pallas_call_tpu_lowering_rule(
   # Dynamic grid bounds have to go at the front.
   dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
   kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
-  out_nodes = mosaic.lower_module_to_custom_call(
-      kernel_ctx,
-      *dynamic_grid_args,
-      *extra_args,
-      *args,
-      module=mosaic_module,
-      out_type=kernel_out_avals,
-      backend="tpu",
-      kernel_name=name_and_src_info.name,
-      cost_estimate=mosaic_params.get("cost_estimate"),
-      vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
-      flags=mosaic_params.get("flags"),
-      allow_input_fusion=mosaic_params.get("allow_input_fusion"),
-      input_output_aliases=input_output_aliases,
-      serialization_format=mosaic_params.get("serialization_format", 1),
-      device_type=mosaic_params.get("device_type"),
-      internal_scratch_in_bytes=mosaic_params.get(
-          "internal_scratch_in_bytes"
-      ),
-      collective_id=mosaic_params.get("collective_id", None),
+  if cost_estimate is not None:
+    mosaic_cost_estimate = pltpu.CostEstimate(
+        flops=cost_estimate.flops,
+        bytes_accessed=cost_estimate.bytes_accessed,
+        transcendentals=cost_estimate.transcendentals,
+    )
+  else:
+    mosaic_cost_estimate = None
+  out_nodes = mosaic.lower_module_to_custom_call(
+      kernel_ctx,
+      *dynamic_grid_args,
+      *extra_args,
+      *args,
+      module=mosaic_module,
+      out_type=kernel_out_avals,
+      backend="tpu",
+      kernel_name=name_and_src_info.name,
+      cost_estimate=mosaic_cost_estimate,
+      vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
+      flags=mosaic_params.get("flags"),
+      allow_input_fusion=mosaic_params.get("allow_input_fusion"),
+      input_output_aliases=input_output_aliases,
+      serialization_format=mosaic_params.get("serialization_format", 1),
+      device_type=mosaic_params.get("device_type"),
+      internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
+      collective_id=mosaic_params.get("collective_id", None),
   )
   _maybe_cast_to_bool = lambda x, aval: x.astype(
       jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x
   def _maybe_cast_outputs(*args):

View File

@@ -146,7 +146,9 @@ def lower_jaxpr_to_module(
     jaxpr: jax_core.Jaxpr,
     name_and_src_info: pallas_core.NameAndSrcInfo,
     compiler_params: dict[str, Any],
+    cost_estimate: pallas_core.CostEstimate | None,
 ) -> LoweringResult:
+  del cost_estimate  # Unused.
   in_structs = tuple(grid_mapping.in_shapes)
   out_structs = grid_mapping.out_shapes
   assert len(jaxpr.outvars) == 0

View File

@@ -36,6 +36,7 @@ def pallas_call_lowering(
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
     compiler_params: dict[str, Any],
+    cost_estimate: pallas_core.CostEstimate | None,
 ):
   del interpret
   if grid_mapping.num_dynamic_grid_bounds:
@@ -58,6 +59,7 @@ def pallas_call_lowering(
       jaxpr,
       name_and_src_info,
       compiler_params,
+      cost_estimate,
   )
   if debug:
     print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:")

View File

@@ -60,6 +60,7 @@ BlockSpec = pallas_core.BlockSpec
 BlockSpecTree = pallas_core.BlockSpecTree
 NoBlockSpec = pallas_core.NoBlockSpec
 no_block_spec = pallas_core.no_block_spec
+CostEstimate = pallas_core.CostEstimate

 # See the docstring for GridMapping for the calling convention
 pallas_call_p = jax_core.Primitive('pallas_call')
@@ -164,6 +165,7 @@ def _get_next_indices(grid, indices):

 def _pallas_call_impl(*args, **kwargs):
   assert False  # We always jit a pallas call, we only need the lowering rule


 def _pallas_call_impl_interpret(
     *args,
     jaxpr: jax_core.Jaxpr,
@@ -171,8 +173,10 @@ def _pallas_call_impl_interpret(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
-    compiler_params: Any):
-  del compiler_params
+    compiler_params: Any,
+    cost_estimate: CostEstimate,
+):
+  del compiler_params, cost_estimate
   # If we're in interpreter mode, we *scan* over the grid and eval the
   # discharged jaxpr.
   dynamic_grid_args, args = split_list(  # type: ignore
@@ -294,6 +298,7 @@ def _pallas_call_impl_interpret(
       out_nopad.append(o)
   return out_nopad

 pallas_call_p.def_impl(_pallas_call_impl)


 def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
@@ -302,9 +307,20 @@ def _pallas_call_impl_interpret(
       for bm in grid_mapping.block_mappings_output)
 pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)


-def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
-                          input_output_aliases: tuple[tuple[int, int], ...],
-                          grid_mapping, debug, interpret, compiler_params: Any):
+def _pallas_call_jvp_rule(
+    primals,
+    tangents,
+    *,
+    jaxpr,
+    name_and_src_info,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    grid_mapping,
+    debug,
+    interpret,
+    compiler_params: Any,
+    cost_estimate: CostEstimate | None,
+):
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError("interpret with dynamic grid bounds unsupported")
   if grid_mapping.num_index_operands:
@@ -346,20 +362,32 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
       num_inputs=grid_mapping.num_inputs * 2,
       num_outputs=grid_mapping.num_outputs * 2,
   )
+  if cost_estimate is not None:
+    jvp_cost_estimate = CostEstimate(
+        flops=2 * cost_estimate.flops,
+        bytes_accessed=2 * cost_estimate.bytes_accessed,
+        transcendentals=2 * cost_estimate.transcendentals,
+    )
+  else:
+    jvp_cost_estimate = None
   out_flat = pallas_call_p.bind(
       *primals,
       *tangents,
       jaxpr=jvp_jaxpr,
       name_and_src_info=name_and_src_info.replace(
-          name=f"{name_and_src_info.name}_jvp"),
+          name=f"{name_and_src_info.name}_jvp"
+      ),
       grid_mapping=jvp_grid_mapping,
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
       compiler_params=compiler_params,
+      cost_estimate=jvp_cost_estimate,
   )
   out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
   return out_primals, out_tangents


 ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule


 def _batch_block_mapping(grid_mapping: GridMapping,
@@ -436,6 +464,7 @@ def _batch_with_explicit_loop(
     debug: bool,
     interpret: bool,
     compiler_params: Any,
+    cost_estimate: CostEstimate | None,
 ):
   """Batch the pallas_call by calling it in loop over the batch size.
@@ -501,6 +530,7 @@ def _batch_with_explicit_loop(
         debug=debug,
         interpret=interpret,
         compiler_params=compiler_params,
+        cost_estimate=cost_estimate,
     )
     for i, batch_out_array in enumerate(batch_out):
       state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -528,6 +558,7 @@ def _pallas_call_batching_rule(
     debug: bool,
     interpret: bool,
     compiler_params: Any,
+    cost_estimate: CostEstimate | None,
 ):
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -536,8 +567,9 @@ def _pallas_call_batching_rule(
       return x
     return jnp.squeeze(x, axis=bdim)

-  axis_size, = {x.shape[d] for x, d in zip(args, dims)
-                if d is not batching.not_mapped}
+  (axis_size,) = {
+      x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped
+  }
   if axis_size == 1:
     # Why are we even vmapping?
     args = map(_maybe_squeeze_out_bdim, args, dims)
@@ -550,6 +582,7 @@ def _pallas_call_batching_rule(
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
+        cost_estimate=cost_estimate,
    )
    return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
@@ -581,6 +614,7 @@ def _pallas_call_batching_rule(
          debug=debug,
          interpret=interpret,
          compiler_params=compiler_params,
+          cost_estimate=cost_estimate,
      )
  else:
    pass  # No dynamic grid dimensions
@@ -613,6 +647,7 @@ def _pallas_call_batching_rule(
        debug=debug,
        interpret=interpret,
        compiler_params=compiler_params,
+        cost_estimate=cost_estimate,
    )

  if not dims:
@@ -657,18 +692,29 @@ def _pallas_call_batching_rule(
       block_mappings=tuple(batched_block_mappings),
       index_map_avals=batched_index_map_avals,
       index_map_tree=batched_index_map_tree,
-      vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims))
+      vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims),
+  )

+  if cost_estimate is not None:
+    batched_cost_estimate = CostEstimate(
+        flops=cost_estimate.flops * axis_size,
+        bytes_accessed=cost_estimate.bytes_accessed * axis_size,
+        transcendentals=cost_estimate.transcendentals * axis_size,
+    )
+  else:
+    batched_cost_estimate = None
   out = pallas_call_p.bind(
       *dynamic_grid_args,
       *args,
       jaxpr=jaxpr,
       name_and_src_info=name_and_src_info.replace(
-          name=f"{name_and_src_info.name}_batched"),
+          name=f"{name_and_src_info.name}_batched"
+      ),
       grid_mapping=batched_grid_mapping,
       input_output_aliases=input_output_aliases,
       debug=debug,
       interpret=interpret,
       compiler_params=compiler_params,
+      cost_estimate=batched_cost_estimate,
   )
   return out, (0,) * len(out)
@@ -971,6 +1017,7 @@ def pallas_call(
    interpret: bool = False,
    name: str | None = None,
    compiler_params: dict[str, Any] | None = None,
+    cost_estimate: CostEstimate | None = None,
 ) -> Callable[..., Any]:
  """Invokes a Pallas kernel on some inputs.
@@ -1106,14 +1153,18 @@ def pallas_call(
     index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
     out_flat = pallas_call_p.bind(
-        *dynamic_grid_bounds, *index_args, *rest_args,
+        *dynamic_grid_bounds,
+        *index_args,
+        *rest_args,
         jaxpr=jaxpr,
         name_and_src_info=name_and_src_info,
         debug=debug,
         interpret=interpret,
         grid_mapping=grid_mapping,
         input_output_aliases=tuple(input_output_aliases.items()),
-        compiler_params=compiler_params)
+        compiler_params=compiler_params,
+        cost_estimate=cost_estimate,
+    )
     out = tree_util.tree_unflatten(out_tree, out_flat)
     return out

   return wrapped

View File

@@ -48,6 +48,7 @@ def pallas_call_lowering(
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
     compiler_params: dict[str, Any],
+    cost_estimate: pallas_core.CostEstimate | None,
 ):
   del interpret
   if grid_mapping.num_dynamic_grid_bounds:

View File

@@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.
 from jax._src.deprecations import register as _register_deprecation
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import IndexingMode
 from jax._src.pallas.core import no_block_spec
 from jax._src.pallas.core import Unblocked

View File

@@ -520,7 +520,7 @@ def gmm(
       (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes
   )
   flops = 2 * m * k * n
-  cost_estimate = pltpu.CostEstimate(
+  cost_estimate = pl.CostEstimate(
       flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
   )
   call_gmm = pl.pallas_call(
call_gmm = pl.pallas_call(
@@ -541,10 +541,10 @@ def gmm(
       compiler_params=dict(
           mosaic=dict(
               dimension_semantics=("parallel", "arbitrary", "arbitrary"),
-              cost_estimate=cost_estimate,
           )
       ),
       interpret=interpret,
+      cost_estimate=cost_estimate,
   )

   out = call_gmm(
@@ -759,7 +759,7 @@ def tgmm(
       (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes
   )
   flops = 2 * m * k * n
-  cost_estimate = pltpu.CostEstimate(
+  cost_estimate = pl.CostEstimate(
       flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
   )
   lhs = lhs.swapaxes(0, 1)

View File

@@ -22,12 +22,12 @@ from jax._src.pallas.mosaic.core import semaphore
 from jax._src.pallas.mosaic.core import SemaphoreType
 from jax._src.pallas.mosaic.core import TPUMemorySpace
 from jax._src.pallas.mosaic.lowering import LoweringException
+from jax._src.pallas.mosaic.pipeline import ARBITRARY
 from jax._src.pallas.mosaic.pipeline import BufferedRef
 from jax._src.pallas.mosaic.pipeline import emit_pipeline
 from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
 from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
 from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
-from jax._src.pallas.mosaic.pipeline import ARBITRARY
 from jax._src.pallas.mosaic.pipeline import PARALLEL
 from jax._src.pallas.mosaic.primitives import async_copy
 from jax._src.pallas.mosaic.primitives import async_remote_copy
@@ -38,14 +38,15 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType
 from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
 from jax._src.pallas.mosaic.primitives import make_async_copy
 from jax._src.pallas.mosaic.primitives import make_async_remote_copy
+from jax._src.pallas.mosaic.primitives import prng_random_bits
+from jax._src.pallas.mosaic.primitives import prng_seed
 from jax._src.pallas.mosaic.primitives import repeat
 from jax._src.pallas.mosaic.primitives import roll
 from jax._src.pallas.mosaic.primitives import semaphore_read
 from jax._src.pallas.mosaic.primitives import semaphore_signal
 from jax._src.pallas.mosaic.primitives import semaphore_wait
-from jax._src.pallas.mosaic.primitives import prng_seed
-from jax._src.pallas.mosaic.primitives import prng_random_bits
 from jax._src.pallas.mosaic.random import to_pallas_key
+# Remove this import after October 22th 2024.
 from jax._src.tpu_custom_call import CostEstimate

 # TODO(cperivol): Temporary alias to the global run_scoped. Remove

View File

@@ -1487,12 +1487,8 @@ class PallasCallTest(PallasBaseTest):
     f = self.pallas_call(
         kernel,
         out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
-        compiler_params=dict(
-            mosaic=dict(
-                cost_estimate=pltpu.CostEstimate(
-                    flops=1234, transcendentals=21, bytes_accessed=12345
-                )
-            )
-        ),
+        cost_estimate=pl.CostEstimate(
+            flops=1234, transcendentals=21, bytes_accessed=12345
+        ),
     )
     (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
@@ -1500,6 +1496,26 @@ class PallasCallTest(PallasBaseTest):
     self.assertEqual(analysis_result['transcendentals'], 21)
     self.assertEqual(analysis_result['bytes accessed'], 12345)

+  def test_cost_analysis_vmap(self):
+    def kernel(x, y):
+      y[:] = x[:]
+
+    batch_size = 3
+    x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128)
+    f = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        cost_estimate=pl.CostEstimate(
+            flops=1234, transcendentals=21, bytes_accessed=12345
+        ),
+    )
+    f = jax.vmap(f)
+    (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
+    self.assertEqual(analysis_result['flops'], batch_size * 1234)
+    self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
+    self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)
+
   def test_vmem_limit(self):
     shape = (128, 128)