diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 69f085d3f..10209c9a8 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -54,7 +54,7 @@ if __name__ == '__main__': proc.wait() thunk_re = re.compile("hlo_op=(.*)#") - cost_dictionary = dict() + cost_dictionary: dict[str, list] = dict() with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile)