mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Log the time it takes to lower from jaxpr to stableHLO
PiperOrigin-RevId: 532115098
This commit is contained in:
parent
843106b73c
commit
8e1ad734bc
@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
|
||||
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
@ -258,6 +259,8 @@ def is_single_device_sharding(sharding) -> bool:
|
||||
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
|
||||
|
||||
|
||||
# TODO(yashkatariya): This API takes in a string which means that string is
|
||||
# created even if it is not going to be logged.
|
||||
@contextlib.contextmanager
|
||||
def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
if _on_exit:
|
||||
|
@ -876,21 +876,24 @@ def lower_parallel_callable(
|
||||
raise ValueError("Ordered effects not supported in `pmap`.")
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
sharding_impls.ReplicaAxisContext(axis_env),
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=replicas.num_global_replicas)
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
sharding_impls.ReplicaAxisContext(axis_env),
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=replicas.num_global_replicas)
|
||||
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
|
||||
shards=shards, tuple_args=tuple_args,
|
||||
unordered_effects=unordered_effects,
|
||||
@ -1916,23 +1919,27 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
in closed_jaxpr.effects):
|
||||
raise ValueError("Ordered effects are not supported for more than 1 device.")
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
# Optionally, override the lowering platform
|
||||
lowering_platform or backend.platform,
|
||||
axis_ctx,
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_mlir_shardings,
|
||||
result_shardings=out_mlir_shardings,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
# Optionally, override the lowering platform
|
||||
lowering_platform or backend.platform,
|
||||
axis_ctx,
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_mlir_shardings,
|
||||
result_shardings=out_mlir_shardings,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions)
|
||||
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
@ -2242,22 +2249,26 @@ def lower_mesh_computation(
|
||||
closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(
|
||||
closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
axis_ctx,
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions,
|
||||
result_shardings=out_partitions,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions)
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished jaxpr to MLIR module conversion {name_stack} "
|
||||
"in {elapsed_time} sec", event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
axis_ctx,
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions,
|
||||
result_shardings=out_partitions,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions)
|
||||
|
||||
return MeshComputation(
|
||||
str(name_stack),
|
||||
lowering_result.module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user