rocm_jax/jax/jaxpr_util.py

209 lines
6.4 KiB
Python
Raw Normal View History

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the Jaxpr IR."""
import collections
[JAX] Add support for generating "equation profiles" in JAX. An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result. For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives: ``` $ pprof --tags /tmp/myprof Main binary filename not available. primitive: Total 6062.0 1509.0 (24.89%): mul 936.0 (15.44%): add 589.0 ( 9.72%): reshape 492.0 ( 8.12%): div 485.0 ( 8.00%): broadcast_in_dim 330.0 ( 5.44%): reduce_sum 322.0 ( 5.31%): integer_pow 230.0 ( 3.79%): add_any 174.0 ( 2.87%): convert_element_type 160.0 ( 2.64%): select 158.0 ( 2.61%): conv_general_dilated 116.0 ( 1.91%): sub 110.0 ( 1.81%): eq 110.0 ( 1.81%): neg 104.0 ( 1.72%): max 53.0 ( 0.87%): rsqrt 52.0 ( 0.86%): rev 49.0 ( 0.81%): custom_jvp_call_jaxpr 49.0 ( 0.81%): gt 5.0 (0.082%): xla_call 4.0 (0.066%): min 3.0 (0.049%): dot_general 3.0 (0.049%): lt 2.0 (0.033%): cos 2.0 (0.033%): exp 2.0 (0.033%): iota 2.0 (0.033%): log 2.0 (0.033%): psum 2.0 (0.033%): reduce_max 2.0 (0.033%): stop_gradient 1.0 (0.016%): argmax 1.0 (0.016%): reduce_window_max 1.0 (0.016%): select_and_scatter_add 1.0 (0.016%): transpose 1.0 (0.016%): xla_pmap ``` Or the lines of code that generated the most equations: ``` $ pprof --text /tmp/myprof Main binary filename not available. 
Type: equations Showing nodes accounting for 6038, 99.60% of 6062 total Dropped 5 nodes (cum <= 30) flat flat% sum% cum cum% 1537 25.35% 25.35% 1537 25.35% _compute_stats 1484 24.48% 49.84% 1484 24.48% _normalize 849 14.01% 63.84% 6062 100% __call__ 644 10.62% 74.46% 644 10.62% <unknown> 483 7.97% 82.43% 483 7.97% <unknown> 392 6.47% 88.90% 6061 100% train_step 324 5.34% 94.24% 324 5.34% <unknown> 161 2.66% 96.90% 161 2.66% <unknown> 57 0.94% 97.84% 4292 70.80% loss_fn 52 0.86% 98.70% 52 0.86% schedule 39 0.64% 99.34% 39 0.64% softmax_cross_entropy 8 0.13% 99.47% 30 0.49% compute_metrics 6 0.099% 99.57% 61 1.01% cross_entropy_loss 1 0.016% 99.59% 1321 21.79% apply_gradients 1 0.016% 99.60% 6062 100% train_and_evaluate 0 0% 99.60% 6062 100% <unknown> 0 0% 99.60% 6062 100% __init__ 0 0% 99.60% 3872 63.87% _call_wrapped_method 0 0% 99.60% 6062 100% _run_and_get_tests_result 0 0% 99.60% 6062 100% _run_code_in_main 0 0% 99.60% 6062 100% _run_in_app 0 0% 99.60% 6062 100% _run_main 0 0% 99.60% 3872 63.87% apply 0 0% 99.60% 161 2.66% apply_updates 0 0% 99.60% 6062 100% main 0 0% 99.60% 6062 100% main_function 0 0% 99.60% 6062 100% run 0 0% 99.60% 6062 100% runTests 0 0% 99.60% 6062 100% run_filename_as_main 0 0% 99.60% 6062 100% run_tests 0 0% 99.60% 3872 63.87% scope_fn 0 0% 99.60% 6062 100% test_train_and_evaluate 0 0% 99.60% 1159 19.12% update_fn 0 0% 99.60% 3872 63.87% wrapped_fn 0 0% 99.60% 3872 63.87% wrapped_module_method 0 0% 99.60% 3872 63.87% wrapper ``` I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization. ``` pprof --tagleaf=primitive --http=:8080 myprof ``` [XLA:Python] Add helpers to Traceback and for working with pprof profiles. * Define hash and equality operators on Tracebacks. * Add functions for converting JSON to and from pprof profile protocol buffers. * Add a helper method that exposes PyCode_Addr2Line to Python. PiperOrigin-RevId: 421395346
2022-01-12 14:27:17 -08:00
import gzip
import itertools
import json
import types
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple
from jax._src import core
from jax._src import util
from jax._src import source_info_util
from jax._src.lib import xla_client
# Rebind `map` and `zip` to the `safe_map` / `safe_zip` helpers from
# jax._src.util for the rest of this module; the Python builtins stay
# reachable under the names `unsafe_map` and `unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
def all_eqns(jaxpr: core.Jaxpr):
  """Yields `(enclosing_jaxpr, eqn)` for every equation reachable from `jaxpr`.

  The equations of `jaxpr` itself come first, in order; afterwards each jaxpr
  produced by `core.subjaxprs(jaxpr)` is walked recursively in the same way.
  """
  yield from ((jaxpr, equation) for equation in jaxpr.eqns)
  for nested in core.subjaxprs(jaxpr):
    yield from all_eqns(nested)
def collect_eqns(jaxpr: core.Jaxpr, key: Callable):
  """Groups every equation reachable from `jaxpr` by `key(eqn)`.

  Args:
    jaxpr: the jaxpr to walk (nested jaxprs are included via `all_eqns`).
    key: callable mapping an equation to a hashable grouping key.

  Returns:
    A plain dict from each key to the list of equations that produced it, in
    the order the equations were visited.
  """
  groups: dict = {}
  for _, equation in all_eqns(jaxpr):
    groups.setdefault(key(equation), []).append(equation)
  return groups
def histogram(jaxpr: core.Jaxpr, key: Callable,
              key_fmt: Callable = lambda x: x):
  """Counts the equations of `jaxpr` that share the same `key(eqn)`.

  Args:
    jaxpr: the jaxpr whose equations are counted.
    key: callable mapping an equation to a hashable grouping key.
    key_fmt: callable applied to each grouping key to produce the key used in
      the returned dict; defaults to the identity.

  Returns:
    A dict mapping `key_fmt(k)` to the number of equations grouped under `k`.
  """
  counts = {}
  for group_key, equations in collect_eqns(jaxpr, key).items():
    counts[key_fmt(group_key)] = len(equations)
  return counts
def primitives(jaxpr: core.Jaxpr):
  """Returns a histogram of equation counts keyed by primitive name."""
  def primitive_name(eqn):
    return eqn.primitive.name

  return histogram(jaxpr, primitive_name)
def primitives_by_source(jaxpr: core.Jaxpr):
  """Returns equation counts keyed by `'<primitive> @ <source summary>'`.

  The source summary is whatever `source_info_util.summarize` returns for the
  equation's `source_info`.
  """
  def name_and_source(eqn):
    location = source_info_util.summarize(eqn.source_info)
    return (eqn.primitive.name, location)

  return histogram(jaxpr, name_and_source, lambda parts: ' @ '.join(parts))
def primitives_by_shape(jaxpr: core.Jaxpr):
  """Returns equation counts keyed by `'<primitive> :: <output avals>'`.

  Each output variable is rendered with `aval.str_short()`, except dropped
  outputs (`core.DropVar`), which are rendered as `'*'`; the renderings of one
  equation's outputs are joined with single spaces.
  """
  def describe_outvar(var):
    if isinstance(var, core.DropVar):
      return '*'
    return var.aval.str_short()

  def name_and_shapes(eqn):
    rendered = ' '.join(describe_outvar(v) for v in eqn.outvars)
    return (eqn.primitive.name, rendered)

  return histogram(jaxpr, name_and_shapes, lambda parts: ' :: '.join(parts))
def source_locations(jaxpr: core.Jaxpr):
  """Returns equation counts keyed by each equation's summarized source info."""
  return histogram(
      jaxpr, lambda eqn: source_info_util.summarize(eqn.source_info))
# An equation, or None when no equation applies. `var_defs_and_refs` passes
# None as the "defining equation" of a jaxpr's constvars/invars and as the
# "reading equation" of the jaxpr's own outvars.
MaybeEqn = Optional[core.JaxprEqn]
def var_defs_and_refs(jaxpr: core.Jaxpr):
  """Records where each variable of `jaxpr` is defined and where it is read.

  Constvars and invars are recorded as defined by `None`; every other variable
  is recorded as defined by the equation that lists it among its outvars.
  Reads by an equation are recorded with that equation; reads by the jaxpr's
  own outvars are recorded with `None`. Literals are never recorded as reads
  and `core.DropVar` outputs are never recorded as definitions.

  Returns:
    When `core.subjaxprs(jaxpr)` yields nothing: the tuple `(jaxpr, res)`,
    where `res` is a list of `(var, defining_eqn_or_None, reading_eqns)`.
    Otherwise: a list whose first element is `(jaxpr, res)` and whose
    remaining elements are the results of calling this function on each
    subjaxpr (so the nesting mirrors the subjaxpr structure).
  """
  defined_by: Dict[core.Var, MaybeEqn] = {}
  read_by: Dict[core.Var, List[MaybeEqn]] = {}

  def write(v: core.Var, eqn: MaybeEqn):
    # A variable must be defined at most once.
    assert v not in defined_by, v
    assert v not in read_by, v
    if isinstance(v, core.DropVar):
      return
    defined_by[v] = eqn
    read_by[v] = []

  def read(a: core.Atom, eqn: MaybeEqn):
    if isinstance(a, core.Literal):
      return
    # A variable must be defined before it is read.
    assert a in defined_by, a
    assert a in read_by, a
    read_by[a].append(eqn)

  for v in itertools.chain(jaxpr.constvars, jaxpr.invars):
    write(v, None)

  for eqn in jaxpr.eqns:
    for a in eqn.invars:
      read(a, eqn)
    for v in eqn.outvars:
      write(v, eqn)

  for a in jaxpr.outvars:
    read(a, None)

  res = [(v, eqn, read_by[v]) for v, eqn in defined_by.items()]
  subs = [var_defs_and_refs(sub) for sub in core.subjaxprs(jaxpr)]
  if not subs:
    return (jaxpr, res)
  return [(jaxpr, res), *subs]
def vars_by_fanout(jaxpr: core.Jaxpr):
  """Counts, per variable, how many times it is read in `jaxpr` and subjaxprs.

  Args:
    jaxpr: the jaxpr to analyze.

  Returns:
    A list of `(jaxpr, histogram)` pairs, one for `jaxpr` and one for each
    nested jaxpr reported by `var_defs_and_refs`. Each histogram maps a string
    describing a variable and its definition to the number of recorded reads
    of that variable.
  """
  def fmt_key(var, eqn):
    # `var_defs_and_refs` records None as the definition of both constvars and
    # invars, so both are labelled "invar" here.
    if eqn is None:
      return f'{var} <- invar'
    src = source_info_util.summarize(eqn.source_info)
    return f'{var} <- {eqn.primitive.name} @ {src}'

  def hist(reads):
    return {fmt_key(var, var_def): len(var_refs)
            for var, var_def, var_refs in reads}

  def flatten(node):
    # `var_defs_and_refs` returns a bare `(jaxpr, res)` tuple for a jaxpr with
    # no subjaxprs and an arbitrarily nested list of such tuples otherwise.
    # Unpacking that result directly fails for a jaxpr without subjaxprs and
    # for subjaxprs that themselves contain subjaxprs, so normalize it to a
    # flat sequence of `(jaxpr, res)` tuples first.
    if isinstance(node, tuple):
      yield node
    else:
      for child in node:
        yield from flatten(child)

  return [(j, hist(reads)) for j, reads in flatten(var_defs_and_refs(jaxpr))]
def print_histogram(histogram: Dict[Any, int]):
  """Prints a histogram to stdout, one `<count> <name>` line per entry.

  Lines are ordered by descending count, with ties broken by descending name.
  Counts are right-aligned to the width of the widest count. An empty
  histogram prints nothing.

  Args:
    histogram: maps names to integer counts.
  """
  if not histogram:
    # `max()` over an empty sequence raises ValueError; there is nothing to
    # print anyway.
    return
  count_width = max(len(str(count)) for count in histogram.values())
  pairs = [(count, name) for name, count in histogram.items()]
  for count, name in reversed(sorted(pairs)):
    print(f'{count:>{count_width}d}', name)
[JAX] Add support for generating "equation profiles" in JAX. An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. Pprof can be used a number of ways to analyze and visualize the result. For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives: ``` $ pprof --tags /tmp/myprof Main binary filename not available. primitive: Total 6062.0 1509.0 (24.89%): mul 936.0 (15.44%): add 589.0 ( 9.72%): reshape 492.0 ( 8.12%): div 485.0 ( 8.00%): broadcast_in_dim 330.0 ( 5.44%): reduce_sum 322.0 ( 5.31%): integer_pow 230.0 ( 3.79%): add_any 174.0 ( 2.87%): convert_element_type 160.0 ( 2.64%): select 158.0 ( 2.61%): conv_general_dilated 116.0 ( 1.91%): sub 110.0 ( 1.81%): eq 110.0 ( 1.81%): neg 104.0 ( 1.72%): max 53.0 ( 0.87%): rsqrt 52.0 ( 0.86%): rev 49.0 ( 0.81%): custom_jvp_call_jaxpr 49.0 ( 0.81%): gt 5.0 (0.082%): xla_call 4.0 (0.066%): min 3.0 (0.049%): dot_general 3.0 (0.049%): lt 2.0 (0.033%): cos 2.0 (0.033%): exp 2.0 (0.033%): iota 2.0 (0.033%): log 2.0 (0.033%): psum 2.0 (0.033%): reduce_max 2.0 (0.033%): stop_gradient 1.0 (0.016%): argmax 1.0 (0.016%): reduce_window_max 1.0 (0.016%): select_and_scatter_add 1.0 (0.016%): transpose 1.0 (0.016%): xla_pmap ``` Or the lines of code that generated the most equations: ``` $ pprof --text /tmp/myprof Main binary filename not available. 
Type: equations Showing nodes accounting for 6038, 99.60% of 6062 total Dropped 5 nodes (cum <= 30) flat flat% sum% cum cum% 1537 25.35% 25.35% 1537 25.35% _compute_stats 1484 24.48% 49.84% 1484 24.48% _normalize 849 14.01% 63.84% 6062 100% __call__ 644 10.62% 74.46% 644 10.62% <unknown> 483 7.97% 82.43% 483 7.97% <unknown> 392 6.47% 88.90% 6061 100% train_step 324 5.34% 94.24% 324 5.34% <unknown> 161 2.66% 96.90% 161 2.66% <unknown> 57 0.94% 97.84% 4292 70.80% loss_fn 52 0.86% 98.70% 52 0.86% schedule 39 0.64% 99.34% 39 0.64% softmax_cross_entropy 8 0.13% 99.47% 30 0.49% compute_metrics 6 0.099% 99.57% 61 1.01% cross_entropy_loss 1 0.016% 99.59% 1321 21.79% apply_gradients 1 0.016% 99.60% 6062 100% train_and_evaluate 0 0% 99.60% 6062 100% <unknown> 0 0% 99.60% 6062 100% __init__ 0 0% 99.60% 3872 63.87% _call_wrapped_method 0 0% 99.60% 6062 100% _run_and_get_tests_result 0 0% 99.60% 6062 100% _run_code_in_main 0 0% 99.60% 6062 100% _run_in_app 0 0% 99.60% 6062 100% _run_main 0 0% 99.60% 3872 63.87% apply 0 0% 99.60% 161 2.66% apply_updates 0 0% 99.60% 6062 100% main 0 0% 99.60% 6062 100% main_function 0 0% 99.60% 6062 100% run 0 0% 99.60% 6062 100% runTests 0 0% 99.60% 6062 100% run_filename_as_main 0 0% 99.60% 6062 100% run_tests 0 0% 99.60% 3872 63.87% scope_fn 0 0% 99.60% 6062 100% test_train_and_evaluate 0 0% 99.60% 1159 19.12% update_fn 0 0% 99.60% 3872 63.87% wrapped_fn 0 0% 99.60% 3872 63.87% wrapped_module_method 0 0% 99.60% 3872 63.87% wrapper ``` I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, and to use the "flame" visualization. ``` pprof --tagleaf=primitive --http=:8080 myprof ``` [XLA:Python] Add helpers to Traceback and for working with pprof profiles. * Define hash and equality operators on Tracebacks. * Add functions for converting JSON to and from pprof profile protocol buffers. * Add a helper method that exposes PyCode_Addr2Line to Python. PiperOrigin-RevId: 421395346
2022-01-12 14:27:17 -08:00
def _pprof_profile(
    profile: Dict[Tuple[Optional[xla_client.Traceback], core.Primitive], int]
) -> bytes:
  """Serializes an equation-count profile as a gzip-compressed pprof profile.

  Args:
    profile: maps `(traceback, primitive)` pairs to counts; the traceback may
      be None, in which case the sample carries no locations.

  Returns:
    The gzip-compressed bytes of the pprof profile built from `profile`.
  """
  def new_id_table():
    # Keys receive consecutive ids 1, 2, 3, ... the first time they are
    # looked up.
    return collections.defaultdict(itertools.count(1).__next__)

  string_ids: DefaultDict[str, int] = new_id_table()
  function_ids: DefaultDict[types.CodeType, int] = new_id_table()
  location_ids: DefaultDict[Tuple[types.CodeType, int], int] = new_id_table()

  # The empty string is inserted first so that it sits at index 0 of the
  # string table with id 0.
  string_ids[""] = 0
  primitive_key = string_ids["primitive"]

  samples = []
  for (tb, primitive), count in profile.items():
    frames = []
    if tb is not None:
      # Keep only frames whose file is considered user code.
      for code, lasti in zip(*tb.raw_frames()):
        if source_info_util.is_user_filename(code.co_filename):
          frames.append(location_ids[(code, lasti)])
    label = {"key": primitive_key, "str": string_ids[primitive.name]}
    samples.append({
        "location_id": frames,
        "value": [count],
        "label": [label],
    })

  locations = []
  for (code, lasti), loc_id in location_ids.items():
    function_id = function_ids[code]
    line_number = xla_client.Traceback.code_addr2line(code, lasti)
    locations.append({
        "id": loc_id,
        "line": [{"function_id": function_id, "line": line_number}],
    })

  functions = []
  for code, func_id in function_ids.items():
    name_id = string_ids[code.co_name]
    functions.append({
        "id": func_id,
        "name": name_id,
        "system_name": name_id,
        "filename": string_ids[code.co_filename],
        "start_line": code.co_firstlineno,
    })

  sample_type = [{"type": string_ids["equations"],
                  "unit": string_ids["count"]}]

  # This is the JSON encoding of a pprof profile protocol buffer. See:
  # https://github.com/google/pprof/blob/master/proto/profile.proto for a
  # description of the format.
  json_profile = json.dumps({
      "string_table": list(string_ids),
      "location": locations,
      "function": functions,
      "sample_type": sample_type,
      "sample": samples,
  })
  serialized = xla_client._xla.json_to_pprof_profile(json_profile)
  return gzip.compress(serialized)
def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
  """Generates a pprof profile that maps jaxpr equations to Python stack traces.

  By visualizing the profile using pprof, one can identify Python code that is
  responsible for yielding large numbers of jaxpr equations.

  Args:
    jaxpr: a Jaxpr.

  Returns:
    A gzip-compressed pprof Profile protocol buffer, suitable for passing to
    pprof tool for visualization.
  """
  counts: DefaultDict[
      Tuple[Optional[xla_client.Traceback], core.Primitive], int]
  counts = collections.defaultdict(int)
  # One sample per distinct (traceback, primitive) pair, weighted by how many
  # equations share that pair.
  for _, equation in all_eqns(jaxpr):
    sample_key = (equation.source_info.traceback, equation.primitive)
    counts[sample_key] += 1
  return _pprof_profile(counts)