# -*- coding: utf-8 -*-
"""Compare convolution calls recorded in two cuDNN API log files.

Extracts the ``I!`` / ``i!`` blocks logged for the convolution compute APIs,
de-duplicates them, and prints cuDNN 7 / cuDNN 8 call pairs whose arguments
match except for differing attributes (typically the chosen ``algo``).
"""
import sys
import pprint
from collections import OrderedDict

# A log line containing any of these substrings is dropped before parsing
# (handles, data/workspace arguments, scaling factors, timestamps, ...), so it
# never contributes to a call's attributes or hashes.
_IGNORED_SUBSTRINGS = (
    "handle", "workSpace", "workSpaceSizeInBytes", "arrayLength",
    "reorderType", "dxData", "vect:", "nbDims:", "dyData", "wData", "alpha",
    "beta", "xData", "yData", "Process", "mathType:", "mode:", "Time:",
)

# Only calls to these APIs are extracted from a log.
_COMPUTE_APIS = (
    "cudnnConvolutionForward",
    "cudnnConvolutionBackwardFilter",
    "cudnnConvolutionBackwardData",
)


class ComputeAPICall:
    """One compute API call parsed from the log.

    Attributes:
        api: name of the logged function, e.g. ``"cudnnConvolutionForward"``.
        original: the retained (non-ignored) log lines, joined by newlines.
        attrs: ordered mapping of dotted argument path (e.g. ``"xDesc.dimA"``)
            to the logged ``val`` string.
    """

    def __init__(self, api, original):
        self.api = api
        self.original = original  # original logs
        self.attrs = OrderedDict()

    def __str__(self):
        return self.api + "\n" + pprint.pformat(self.attrs) + "\n"

    __repr__ = __str__

    def shape_hash(self):
        """Hash the API name and every attribute except ``algo``.

        Calls with equal shape hashes differ at most in ``algo`` (barring hash
        collisions). A call without an ``algo`` entry is hashed on all of its
        attributes instead of raising ``KeyError``.
        """
        attrs = dict(self.attrs)  # values are strings: a shallow copy suffices
        attrs.pop("algo", None)
        return hash(frozenset(attrs.items()) | frozenset({self.api}))

    def all_hash(self):
        """Hash the API name and all attributes, including ``algo``."""
        return hash(frozenset(self.attrs.items()) | frozenset({self.api}))


def _parse_arg_line(line):
    """Parse one ``name: key=value; key=value;`` argument line.

    Returns a dict holding ``"name"`` plus every ``key=value`` pair except
    ``type``, or ``None`` when the line has no ``name:`` prefix. Only the
    first ``:`` / ``=`` splits, so values may themselves contain them.
    """
    head, sep, rest = line.strip(":").partition(":")
    if not sep:
        return None
    parsed = {"name": head}
    for segment in rest.strip(";").split(";"):
        key, eq, value = segment.strip().partition("=")
        if not eq or key == "type":
            continue  # empty / malformed segment, or the uninteresting type
        parsed[key] = value
    return parsed


def parse_compute_block(api, file):
    """Read one call's argument block from *file* and return a ComputeAPICall.

    Consumes lines from the line iterator *file* up to and including the
    first blank line (or end of input). Lines containing any of
    ``_IGNORED_SUBSTRINGS`` are dropped. The remaining lines (with an
    optional ``i!`` prefix) are nested by indentation, and every argument
    carrying a ``val`` is stored in ``attrs`` under its dotted path.

    Args:
        api: API name recorded on the returned call.
        file: iterator of text lines positioned just after the call header.

    Returns:
        The parsed ComputeAPICall.
    """
    kept = []
    for raw in file:
        raw = raw.strip()
        if not raw:
            break  # a blank line terminates the block
        if any(token in raw for token in _IGNORED_SUBSTRINGS):
            continue
        kept.append(raw)

    call = ComputeAPICall(api, "\n".join(kept))

    # Stack of (indent, name) for the arguments enclosing the current line.
    # Popping every entry at least as deep as the current line resolves
    # dedents of any depth and uneven indent steps.
    scope = []
    for text in kept:
        if text.startswith("i!"):
            text = text[2:]
        body = text.lstrip()
        indent = len(text) - len(body)
        parsed = _parse_arg_line(body)
        if parsed is None:
            continue
        while scope and scope[-1][0] >= indent:
            scope.pop()
        scope.append((indent, parsed["name"]))
        if "val" in parsed:
            call.attrs[".".join(name for _, name in scope)] = parsed["val"]
    return call


def find_compute_calls(filename):
    """Yield a ComputeAPICall for each compute API call logged in *filename*.

    A call starts at an ``I!`` line containing ``<api>()`` for one of
    ``_COMPUTE_APIS``. Its argument block is then read from the same file
    iterator by :func:`parse_compute_block`.
    """
    with open(filename, encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line.startswith("I!"):
                continue
            for api in _COMPUTE_APIS:
                if api + "()" in line:
                    yield parse_compute_block(api, f)
                    break  # the block is consumed; don't re-match this header


def find_compute_calls_dedup(filename, max_dups=2000):
    """Yield the calls from :func:`find_compute_calls`, skipping repeats.

    Calls are compared by :meth:`ComputeAPICall.all_hash` (so a hash
    collision would hide a distinct call). Scanning stops once more than
    *max_dups* duplicates have been skipped in total.
    """
    seen = set()
    dups = 0
    calls = find_compute_calls(filename)
    try:
        for call in calls:
            key = call.all_hash()
            if key not in seen:
                seen.add(key)
                yield call
                continue
            dups += 1
            if dups > max_dups:
                # found that many dups, assume that no more new convs will appear
                return
    finally:
        calls.close()  # release the log file even when stopping early


def main(argv=None):
    """Print cuDNN 7 / cuDNN 8 call pairs with equal shape but different attrs.

    Usage: ``script.py [V7_LOG V8_LOG]``. Without both arguments the files
    ``cudnnlog_cudnn7_cu102.txt`` and ``cudnnlog_cudnn8_cu102.txt`` (relative
    to the working directory) are compared.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) >= 2:
        v7_path, v8_path = argv[0], argv[1]
    else:
        v7_path = "cudnnlog_cudnn7_cu102.txt"
        v8_path = "cudnnlog_cudnn8_cu102.txt"

    v7 = list(find_compute_calls_dedup(v7_path))
    v8 = list(find_compute_calls_dedup(v8_path))
    v7_map = {call.shape_hash(): call for call in v7}
    for v8_call in v8:
        v7_call = v7_map.get(v8_call.shape_hash())
        if v7_call is not None:
            if v7_call.all_hash() != v8_call.all_hash():
                print("v7", v7_call)
                print("v8", v8_call)
                print('--' * 10)
        else:
            # NOTE(review): interactive debugging hook for a cuDNN 8 call with
            # no same-shape cuDNN 7 call; requires IPython to be installed.
            import IPython as IP
            IP.embed()


if __name__ == "__main__":
    main()