mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add fingerprint to debugging log
This commit is contained in:
parent
56e9f7cb92
commit
369ca134fc
@ -59,6 +59,10 @@ from . import partial_eval as pe
|
||||
from . import xla
|
||||
from . import ad
|
||||
|
||||
# Built in Python lists don't support weak refs but subclasses of lists do.
|
||||
class WeakRefList(list):
|
||||
pass
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from functools import cached_property as maybe_cached_property
|
||||
else:
|
||||
@ -616,17 +620,19 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
|
||||
global_axis_size, devices, name, in_axes, out_axes_thunk,
|
||||
donated_invars, global_arg_shapes):
|
||||
abstract_args = unsafe_map(xla.abstractify, args)
|
||||
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
|
||||
compiled_fun, xla_executable = parallel_callable(fun, backend, axis_name, axis_size,
|
||||
global_axis_size, devices, name,
|
||||
in_axes, out_axes_thunk,
|
||||
donated_invars, global_arg_shapes,
|
||||
*abstract_args)
|
||||
|
||||
# Don't re-abstractify args unless logging is enabled for performance.
|
||||
if config.jax_distributed_debug:
|
||||
distributed_debug_log(("Running pmapped function", name),
|
||||
("python function", fun.f),
|
||||
("devices", devices),
|
||||
("abstract args", map(xla.abstractify, args)))
|
||||
("abstract args", map(xla.abstractify, args)),
|
||||
("fingerprint", xla_executable.fingerprint))
|
||||
return compiled_fun(*args)
|
||||
|
||||
@lu.cache
|
||||
@ -890,7 +896,8 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
handle_outs)
|
||||
compiled = xla.backend_compile(backend, built, compile_options)
|
||||
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
|
||||
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
|
||||
execute_fun = partial(execute_replicated, compiled, backend, handle_args, handle_outs)
|
||||
return WeakRefList([execute_fun, compiled])
|
||||
|
||||
multi_host_supported_collectives: Set[core.Primitive] = set()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user