add fingerprint to debugging log

This commit is contained in:
Liyah Coleman 2021-05-24 17:31:20 +00:00
parent 56e9f7cb92
commit 369ca134fc

View File

@ -59,6 +59,10 @@ from . import partial_eval as pe
from . import xla
from . import ad
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
pass
if sys.version_info >= (3, 8):
from functools import cached_property as maybe_cached_property
else:
@ -616,17 +620,19 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
global_axis_size, devices, name, in_axes, out_axes_thunk,
donated_invars, global_arg_shapes):
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
compiled_fun, xla_executable = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, name,
in_axes, out_axes_thunk,
donated_invars, global_arg_shapes,
*abstract_args)
# Don't re-abstractify args unless logging is enabled for performance.
if config.jax_distributed_debug:
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
("abstract args", map(xla.abstractify, args)))
("abstract args", map(xla.abstractify, args)),
("fingerprint", xla_executable.fingerprint))
return compiled_fun(*args)
@lu.cache
@ -890,7 +896,8 @@ def parallel_callable(fun: lu.WrappedFun,
handle_outs)
compiled = xla.backend_compile(backend, built, compile_options)
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
execute_fun = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
return WeakRefList([execute_fun, compiled])
multi_host_supported_collectives: Set[core.Primitive] = set()