From 9762ac53c87a26b756be9c5e84eec8d98978d54e Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 5 Aug 2024 08:17:18 -0700 Subject: [PATCH] Move CostEstimate from pltpu to pl * Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`. Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months' time. * Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis. PiperOrigin-RevId: 659560330 --- jax/_src/pallas/core.py | 13 ++++ .../pallas/mosaic/pallas_call_registration.py | 69 ++++++++++++------ jax/_src/pallas/mosaic_gpu/lowering.py | 2 + .../mosaic_gpu/pallas_call_registration.py | 2 + jax/_src/pallas/pallas_call.py | 73 ++++++++++++++++--- .../pallas/triton/pallas_call_registration.py | 1 + jax/experimental/pallas/__init__.py | 1 + .../pallas/ops/tpu/megablox/gmm.py | 6 +- jax/experimental/pallas/tpu.py | 7 +- tests/pallas/tpu_pallas_test.py | 28 +++++-- 10 files changed, 157 insertions(+), 45 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 08c1c20e3..e99510f8d 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -826,3 +826,16 @@ class PallasMesh(mesh_lib.Mesh): @property def _is_jax_device_mesh(self): return False + + +@dataclasses.dataclass(frozen=True) +class CostEstimate: + flops: int + transcendentals: int + bytes_accessed: int + + def to_json(self) -> bytes: + return ( + f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},' + f' "bytes_accessed": {self.bytes_accessed}}}' + ).encode("ascii") diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 959519789..c6edddca0 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -22,8 +22,8 @@ from typing import Any import warnings import jax -from jax import dtypes from jax import core as jax_core 
+from jax import dtypes from jax._src import config from jax._src import core as jax_src_core from jax._src import sharding_impls @@ -34,6 +34,7 @@ from jax._src.pallas.mosaic import lowering from jax._src.pallas.mosaic import verification from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu +from jax.experimental.pallas import tpu as pltpu def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): """Casts boolean values to integers. @@ -71,7 +72,9 @@ def pallas_call_tpu_lowering_rule( input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, - compiler_params: dict[str, Any]): + compiler_params: dict[str, Any], + cost_estimate: core.CostEstimate | None, +): """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret if debug: @@ -90,6 +93,22 @@ def pallas_call_tpu_lowering_rule( mosaic_params = compiler_params["mosaic"] else: mosaic_params = {} + + if "cost_estimate" in mosaic_params: + # TODO(amagni): Remove this branch after October 22nd 2024. + if cost_estimate is not None: + raise ValueError( + "Passing cost estimate via both compiler_params=dict(mosaic=...) and" + " pallas_call(..., cost_estimate=...) is not supported." + ) + + warnings.warn( + "Passing cost estimate via compiler_params=dict(cost_estimate=...) is" + " deprecated. Use pallas_call(..., cost_estimate=...) instead.", + DeprecationWarning, + ) + cost_estimate = mosaic_params["cost_estimate"] + mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: @@ -172,27 +191,33 @@ def pallas_call_tpu_lowering_rule( # Dynamic grid bounds have to go at the front. 
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) - out_nodes = mosaic.lower_module_to_custom_call( - kernel_ctx, - *dynamic_grid_args, - *extra_args, - *args, - module=mosaic_module, - out_type=kernel_out_avals, - backend="tpu", - kernel_name=name_and_src_info.name, - cost_estimate=mosaic_params.get("cost_estimate"), - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), - flags=mosaic_params.get("flags"), - allow_input_fusion=mosaic_params.get("allow_input_fusion"), - input_output_aliases=input_output_aliases, - serialization_format=mosaic_params.get("serialization_format", 1), - device_type=mosaic_params.get("device_type"), - internal_scratch_in_bytes=mosaic_params.get( - "internal_scratch_in_bytes" - ), - collective_id=mosaic_params.get("collective_id", None), + if cost_estimate is not None: + mosaic_cost_estimate = pltpu.CostEstimate( + flops=cost_estimate.flops, + bytes_accessed=cost_estimate.bytes_accessed, + transcendentals=cost_estimate.transcendentals, ) + else: + mosaic_cost_estimate = None + out_nodes = mosaic.lower_module_to_custom_call( + kernel_ctx, + *dynamic_grid_args, + *extra_args, + *args, + module=mosaic_module, + out_type=kernel_out_avals, + backend="tpu", + kernel_name=name_and_src_info.name, + cost_estimate=mosaic_cost_estimate, + vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), + flags=mosaic_params.get("flags"), + allow_input_fusion=mosaic_params.get("allow_input_fusion"), + input_output_aliases=input_output_aliases, + serialization_format=mosaic_params.get("serialization_format", 1), + device_type=mosaic_params.get("device_type"), + internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), + collective_id=mosaic_params.get("collective_id", None), + ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x def _maybe_cast_outputs(*args): diff --git 
a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 7dfd0258c..be9168e07 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -146,7 +146,9 @@ def lower_jaxpr_to_module( jaxpr: jax_core.Jaxpr, name_and_src_info: pallas_core.NameAndSrcInfo, compiler_params: dict[str, Any], + cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: + del cost_estimate # Unused. in_structs = tuple(grid_mapping.in_shapes) out_structs = grid_mapping.out_shapes assert len(jaxpr.outvars) == 0 diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 1409800fd..9f28fa7c2 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -36,6 +36,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], + cost_estimate: pallas_core.CostEstimate | None, ): del interpret if grid_mapping.num_dynamic_grid_bounds: @@ -58,6 +59,7 @@ def pallas_call_lowering( jaxpr, name_and_src_info, compiler_params, + cost_estimate, ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:") diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 619110593..f6ee5381a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -60,6 +60,7 @@ BlockSpec = pallas_core.BlockSpec BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec +CostEstimate = pallas_core.CostEstimate # See the docstring for GridMapping for the calling convention pallas_call_p = jax_core.Primitive('pallas_call') @@ -164,6 +165,7 @@ def _get_next_indices(grid, indices): def _pallas_call_impl(*args, **kwargs): assert False # We always jit a pallas call, we only need the lowering rule + 
def _pallas_call_impl_interpret( *args, jaxpr: jax_core.Jaxpr, @@ -171,8 +173,10 @@ def _pallas_call_impl_interpret( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, - compiler_params: Any): - del compiler_params + compiler_params: Any, + cost_estimate: CostEstimate, +): + del compiler_params, cost_estimate # If we're in interpreter mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore @@ -294,6 +298,7 @@ def _pallas_call_impl_interpret( out_nopad.append(o) return out_nopad + pallas_call_p.def_impl(_pallas_call_impl) def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): @@ -302,9 +307,20 @@ def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): for bm in grid_mapping.block_mappings_output) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info, + +def _pallas_call_jvp_rule( + primals, + tangents, + *, + jaxpr, + name_and_src_info, input_output_aliases: tuple[tuple[int, int], ...], - grid_mapping, debug, interpret, compiler_params: Any): + grid_mapping, + debug, + interpret, + compiler_params: Any, + cost_estimate: CostEstimate | None, +): if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") if grid_mapping.num_index_operands: @@ -346,20 +362,32 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info, num_inputs=grid_mapping.num_inputs * 2, num_outputs=grid_mapping.num_outputs * 2, ) + if cost_estimate is not None: + jvp_cost_estimate = CostEstimate( + flops=2 * cost_estimate.flops, + bytes_accessed=2 * cost_estimate.bytes_accessed, + transcendentals=2 * cost_estimate.transcendentals, + ) + else: + jvp_cost_estimate = None out_flat = pallas_call_p.bind( *primals, *tangents, jaxpr=jvp_jaxpr, name_and_src_info=name_and_src_info.replace( - 
name=f"{name_and_src_info.name}_jvp"), + name=f"{name_and_src_info.name}_jvp" + ), grid_mapping=jvp_grid_mapping, interpret=interpret, debug=debug, input_output_aliases=(), compiler_params=compiler_params, + cost_estimate=jvp_cost_estimate, ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents + + ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule def _batch_block_mapping(grid_mapping: GridMapping, @@ -436,6 +464,7 @@ def _batch_with_explicit_loop( debug: bool, interpret: bool, compiler_params: Any, + cost_estimate: CostEstimate | None, ): """Batch the pallas_call by calling it in loop over the batch size. @@ -501,6 +530,7 @@ def _batch_with_explicit_loop( debug=debug, interpret=interpret, compiler_params=compiler_params, + cost_estimate=cost_estimate, ) for i, batch_out_array in enumerate(batch_out): state[i] = jax.lax.dynamic_update_index_in_dim( @@ -528,6 +558,7 @@ def _pallas_call_batching_rule( debug: bool, interpret: bool, compiler_params: Any, + cost_estimate: CostEstimate | None, ): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -536,8 +567,9 @@ def _pallas_call_batching_rule( return x return jnp.squeeze(x, axis=bdim) - axis_size, = {x.shape[d] for x, d in zip(args, dims) - if d is not batching.not_mapped} + (axis_size,) = { + x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped + } if axis_size == 1: # Why are we even vmapping? 
args = map(_maybe_squeeze_out_bdim, args, dims) @@ -550,6 +582,7 @@ def _pallas_call_batching_rule( debug=debug, interpret=interpret, compiler_params=compiler_params, + cost_estimate=cost_estimate, ) return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) @@ -581,6 +614,7 @@ def _pallas_call_batching_rule( debug=debug, interpret=interpret, compiler_params=compiler_params, + cost_estimate=cost_estimate, ) else: pass # No dynamic grid dimensions @@ -613,6 +647,7 @@ def _pallas_call_batching_rule( debug=debug, interpret=interpret, compiler_params=compiler_params, + cost_estimate=cost_estimate, ) if not dims: @@ -657,18 +692,29 @@ def _pallas_call_batching_rule( block_mappings=tuple(batched_block_mappings), index_map_avals=batched_index_map_avals, index_map_tree=batched_index_map_tree, - vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims)) + vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), + ) + if cost_estimate is not None: + batched_cost_estimate = CostEstimate( + flops=cost_estimate.flops * axis_size, + bytes_accessed=cost_estimate.bytes_accessed * axis_size, + transcendentals=cost_estimate.transcendentals * axis_size, + ) + else: + batched_cost_estimate = None out = pallas_call_p.bind( *dynamic_grid_args, *args, jaxpr=jaxpr, name_and_src_info=name_and_src_info.replace( - name=f"{name_and_src_info.name}_batched"), + name=f"{name_and_src_info.name}_batched" + ), grid_mapping=batched_grid_mapping, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, compiler_params=compiler_params, + cost_estimate=batched_cost_estimate, ) return out, (0,) * len(out) @@ -971,6 +1017,7 @@ def pallas_call( interpret: bool = False, name: str | None = None, compiler_params: dict[str, Any] | None = None, + cost_estimate: CostEstimate | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. 
@@ -1106,14 +1153,18 @@ def pallas_call( index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) out_flat = pallas_call_p.bind( - *dynamic_grid_bounds, *index_args, *rest_args, + *dynamic_grid_bounds, + *index_args, + *rest_args, jaxpr=jaxpr, name_and_src_info=name_and_src_info, debug=debug, interpret=interpret, grid_mapping=grid_mapping, input_output_aliases=tuple(input_output_aliases.items()), - compiler_params=compiler_params) + compiler_params=compiler_params, + cost_estimate=cost_estimate, + ) out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index e89e1323c..b94adfb8f 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -48,6 +48,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], + cost_estimate: pallas_core.CostEstimate | None, ): del interpret if grid_mapping.num_dynamic_grid_bounds: diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index e6bde1924..9a768ed53 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html. 
from jax._src.deprecations import register as _register_deprecation from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec +from jax._src.pallas.core import CostEstimate from jax._src.pallas.core import IndexingMode from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index eb197b186..320851422 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -520,7 +520,7 @@ def gmm( (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes ) flops = 2 * m * k * n - cost_estimate = pltpu.CostEstimate( + cost_estimate = pl.CostEstimate( flops=flops, bytes_accessed=bytes_accessed, transcendentals=0 ) call_gmm = pl.pallas_call( @@ -541,10 +541,10 @@ def gmm( compiler_params=dict( mosaic=dict( dimension_semantics=("parallel", "arbitrary", "arbitrary"), - cost_estimate=cost_estimate, ) ), interpret=interpret, + cost_estimate=cost_estimate, ) out = call_gmm( @@ -759,7 +759,7 @@ def tgmm( (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes ) flops = 2 * m * k * n - cost_estimate = pltpu.CostEstimate( + cost_estimate = pl.CostEstimate( flops=flops, bytes_accessed=bytes_accessed, transcendentals=0 ) lhs = lhs.swapaxes(0, 1) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 98f9d5c0f..79d773379 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -22,12 +22,12 @@ from jax._src.pallas.mosaic.core import semaphore from jax._src.pallas.mosaic.core import SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace from jax._src.pallas.mosaic.lowering import LoweringException +from jax._src.pallas.mosaic.pipeline import ARBITRARY from jax._src.pallas.mosaic.pipeline import BufferedRef from jax._src.pallas.mosaic.pipeline import emit_pipeline from 
jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations -from jax._src.pallas.mosaic.pipeline import ARBITRARY from jax._src.pallas.mosaic.pipeline import PARALLEL from jax._src.pallas.mosaic.primitives import async_copy from jax._src.pallas.mosaic.primitives import async_remote_copy @@ -38,14 +38,15 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas.mosaic.primitives import prng_seed from jax._src.pallas.mosaic.primitives import repeat from jax._src.pallas.mosaic.primitives import roll from jax._src.pallas.mosaic.primitives import semaphore_read from jax._src.pallas.mosaic.primitives import semaphore_signal from jax._src.pallas.mosaic.primitives import semaphore_wait -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.pallas.mosaic.random import to_pallas_key +# Remove this import after October 22nd 2024. from jax._src.tpu_custom_call import CostEstimate # TODO(cperivol): Temporary alias to the global run_scoped. 
Remove diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 358b81dbd..481e301e2 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1487,12 +1487,8 @@ class PallasCallTest(PallasBaseTest): f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=dict( - mosaic=dict( - cost_estimate=pltpu.CostEstimate( - flops=1234, transcendentals=21, bytes_accessed=12345 - ) - ) + cost_estimate=pl.CostEstimate( + flops=1234, transcendentals=21, bytes_accessed=12345 ), ) (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis() @@ -1500,6 +1496,26 @@ class PallasCallTest(PallasBaseTest): self.assertEqual(analysis_result['transcendentals'], 21) self.assertEqual(analysis_result['bytes accessed'], 12345) + + def test_cost_analysis_vmap(self): + def kernel(x, y): + y[:] = x[:] + batch_size = 3 + x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + cost_estimate=pl.CostEstimate( + flops=1234, transcendentals=21, bytes_accessed=12345 + ), + ) + f = jax.vmap(f) + (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis() + self.assertEqual(analysis_result['flops'], batch_size * 1234) + self.assertEqual(analysis_result['transcendentals'], batch_size * 21) + self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345) + + def test_vmem_limit(self): shape = (128, 128)