[Pallas] Add name parameter to core_map

PiperOrigin-RevId: 731536152
This commit is contained in:
Sharad Vikram 2025-02-26 18:58:21 -08:00 committed by jax authors
parent 0f0d5e90ef
commit 1ecbac9702
3 changed files with 9 additions and 1 deletions

View File

@ -1016,6 +1016,7 @@ def core_map(
interpret: bool = False,
debug: bool = False,
cost_estimate: CostEstimate | None = None,
name: str | None = None,
):
"""Runs a function on a mesh, mapping it over the devices in the mesh.
@ -1030,6 +1031,7 @@ def core_map(
cost_estimate: The cost estimate of the function.
"""
def wrapped(f):
name_ = name or f.__name__
flat_args, in_tree = tree_util.tree_flatten(((), {}))
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f,
@ -1042,7 +1044,7 @@ def core_map(
compiler_params=compiler_params,
interpret=interpret,
debug=debug,
cost_estimate=cost_estimate)
cost_estimate=cost_estimate, name=name_)
if out:
raise ValueError("core_map-ped functions must not return any outputs.")
return tree_util.tree_unflatten(out_tree_thunk(), out)
@ -1080,6 +1082,7 @@ def default_mesh_discharge_rule(
debug,
interpret,
cost_estimate,
name,
):
"""Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
del out_avals # Unused.
@ -1100,6 +1103,7 @@ def default_mesh_discharge_rule(
from jax._src.pallas import pallas_call # Avoid circular dependency.
outs = pallas_call.pallas_call(
body,
name=name,
out_shape=[in_avals[idx] for idx in modified_idxs],
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(modified_idxs),

View File

@ -246,6 +246,7 @@ def _tensorcore_mesh_discharge_rule(
interpret: bool,
debug: bool,
cost_estimate: pallas_core.CostEstimate | None,
name: str,
):
assert isinstance(mesh, TensorCoreMesh)
if compiler_params and not isinstance(compiler_params, TPUCompilerParams):
@ -274,6 +275,7 @@ def _tensorcore_mesh_discharge_rule(
interpret=interpret,
backend="mosaic_tpu",
cost_estimate=cost_estimate,
name=name,
)
pallas_core._core_map_mesh_rules[TensorCoreMesh] = (

View File

@ -545,6 +545,7 @@ def _gpu_mesh_discharge_rule(
interpret,
debug,
cost_estimate,
name,
):
if not isinstance(mesh, GPUMesh):
raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
@ -568,6 +569,7 @@ def _gpu_mesh_discharge_rule(
debug=debug,
interpret=interpret,
cost_estimate=cost_estimate,
name=name,
)