Mirror of https://github.com/ROCm/jax.git
[Pallas] Add name parameter to core_map

PiperOrigin-RevId: 731536152
commit 1ecbac9702 (parent 0f0d5e90ef)
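This commit threads an optional ``name`` argument through ``core_map`` and the per-mesh discharge rules down to ``pallas_call``, so a kernel can carry a user-chosen name instead of defaulting to the wrapped function's ``__name__``. A minimal usage sketch (illustrative only, not part of the commit; assumes the decorator-style ``core_map`` API exported as ``jax.experimental.pallas.core_map`` and a TPU TensorCore mesh):

    # Sketch: mesh construction and all names here are assumptions, not from this diff.
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    mesh = pltpu.create_tensorcore_mesh("core")

    @pl.core_map(mesh, name="fused_update")  # new: explicit kernel name
    def kernel():
      ...  # per-core body; core_map-ped functions must not return outputs

    # Without name=..., the kernel would be named "kernel" after the function.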
@@ -1016,6 +1016,7 @@ def core_map(
     interpret: bool = False,
     debug: bool = False,
     cost_estimate: CostEstimate | None = None,
+    name: str | None = None,
 ):
   """Runs a function on a mesh, mapping it over the devices in the mesh.

@@ -1030,6 +1031,7 @@ def core_map(
     cost_estimate: The cost estimate of the function.
   """
   def wrapped(f):
+    name_ = name or f.__name__
     flat_args, in_tree = tree_util.tree_flatten(((), {}))
     flat_fun, out_tree_thunk = api_util.flatten_fun(
         lu.wrap_init(f,
@@ -1042,7 +1044,7 @@ def core_map(
         compiler_params=compiler_params,
         interpret=interpret,
         debug=debug,
-        cost_estimate=cost_estimate)
+        cost_estimate=cost_estimate, name=name_)
    if out:
      raise ValueError("core_map-ped functions must not return any outputs.")
    return tree_util.tree_unflatten(out_tree_thunk(), out)
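The added ``name_ = name or f.__name__`` line is a plain fallback: an explicit ``name`` wins, otherwise the wrapped function's own name is used. A self-contained illustration of the idiom (``pick_name`` is a hypothetical helper, not from the commit):

    def pick_name(name, f):
      return name or f.__name__  # explicit name wins; else the function's name

    def kernel(): pass
    assert pick_name(None, kernel) == "kernel"
    assert pick_name("fused_update", kernel) == "fused_update"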
@@ -1080,6 +1082,7 @@ def default_mesh_discharge_rule(
     debug,
     interpret,
     cost_estimate,
+    name,
 ):
   """Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
   del out_avals  # Unused.
@@ -1100,6 +1103,7 @@ def default_mesh_discharge_rule(
   from jax._src.pallas import pallas_call  # Avoid circular dependency.
   outs = pallas_call.pallas_call(
       body,
+      name=name,
       out_shape=[in_avals[idx] for idx in modified_idxs],
       in_specs=[any_spec] * len(in_avals),
       out_specs=[any_spec] * len(modified_idxs),
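``default_mesh_discharge_rule`` now forwards ``name`` into ``pallas_call``, which already accepts a ``name`` parameter. A runnable sketch of that destination API (interpret mode and shapes are illustrative; ``copy_kernel`` is a made-up example, not from the commit):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]  # copy the input block to the output block

    out = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        name="copy_kernel",  # the same parameter the discharge rule now sets
        interpret=True,      # interpreter mode; runs without an accelerator
    )(jnp.ones((8, 128), jnp.float32))
    assert (out == 1.0).all()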
@@ -246,6 +246,7 @@ def _tensorcore_mesh_discharge_rule(
     interpret: bool,
     debug: bool,
     cost_estimate: pallas_core.CostEstimate | None,
+    name: str,
 ):
   assert isinstance(mesh, TensorCoreMesh)
   if compiler_params and not isinstance(compiler_params, TPUCompilerParams):
@@ -274,6 +275,7 @@ def _tensorcore_mesh_discharge_rule(
       interpret=interpret,
       backend="mosaic_tpu",
       cost_estimate=cost_estimate,
+      name=name,
   )

pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
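The assignment to ``pallas_core._core_map_mesh_rules[TensorCoreMesh]`` registers the TPU rule in a registry keyed by mesh type; ``core_map`` dispatches on ``type(mesh)``, which is why each backend rule (TPU above, GPU below) must accept the new ``name`` parameter. A generic sketch of that dispatch pattern (hypothetical names, not JAX's actual internals):

    # Hypothetical type-keyed rule registry.
    _rules = {}

    def register_rule(mesh_cls, rule):
      _rules[mesh_cls] = rule

    def discharge(mesh, *args, name=None, **kwargs):
      rule = _rules[type(mesh)]  # look up the rule for this mesh's concrete type
      return rule(mesh, *args, name=name, **kwargs)  # every rule must accept name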
@@ -545,6 +545,7 @@ def _gpu_mesh_discharge_rule(
     interpret,
     debug,
     cost_estimate,
+    name,
 ):
   if not isinstance(mesh, GPUMesh):
     raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
@@ -568,6 +569,7 @@ def _gpu_mesh_discharge_rule(
       debug=debug,
       interpret=interpret,
       cost_estimate=cost_estimate,
+      name=name,
   )
