Last active: September 26, 2023 17:04
Revisions
Codegass revised this gist
Sep 26, 2023. 1 changed file with 10 additions and 3 deletions.
The revision replaces the save-on-every-new-token behavior with a batched save that writes the cache after every 50 newly embedded tokens:

@@ -8,6 +8,7 @@ class MethodNameEmbedding:
         self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
         self.model = AutoModel.from_pretrained("microsoft/codebert-base")
         self.cache_filepath = cache_filepath
+        self.new_tokens_count = 0
         # Load the cache from disk if the file exists
         try:
@@ -26,6 +27,11 @@ class MethodNameEmbedding:
         tokens.append(name[start:])
         return tokens
 
+    def save_cache(self):
+        with open(self.cache_filepath, 'wb') as f:
+            pickle.dump(self.token_cache, f)
+        self.new_tokens_count = 0
+
     def get_embedding(self, token: str) -> torch.Tensor:
         if token in self.token_cache:
             return self.token_cache[token]
@@ -34,10 +40,11 @@ class MethodNameEmbedding:
         token_ids = self.tokenizer.convert_tokens_to_ids(nl_tokens)
         embedding = self.model(torch.tensor(token_ids)[None,:])[1]
         self.token_cache[token] = embedding
-        # Save the cache to disk
-        with open(self.cache_filepath, 'wb') as f:
-            pickle.dump(self.token_cache, f)
+        self.new_tokens_count += 1
+        # Save the cache once 50 new tokens have been added
+        if self.new_tokens_count >= 50:
+            self.save_cache()
         return embedding
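A minimal caller-side sketch (not part of the gist): since the cache is now only written after every 50 new tokens, up to 49 freshly embedded tokens exist only in memory, so it can be worth registering save_cache to run at process exit. The module name method_name_embedding used in the import is an assumption; adjust it to wherever the class is defined.

import atexit

# Hypothetical module name -- adjust to wherever MethodNameEmbedding lives.
from method_name_embedding import MethodNameEmbedding

embedder = MethodNameEmbedding("token_cache.pkl")
# Flush any tokens embedded since the last batched save when the process exits.
atexit.register(embedder.save_cache)

print(embedder.cosine_similarity_method_names("readFile", "loadDocument"))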
Codegass created this gist
Sep 26, 2023.
New file (58 additions):

import pickle
from transformers import AutoTokenizer, AutoModel
import torch
from torch.nn.functional import cosine_similarity


class MethodNameEmbedding:
    def __init__(self, cache_filepath="token_cache.pkl"):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.model = AutoModel.from_pretrained("microsoft/codebert-base")
        self.cache_filepath = cache_filepath
        # Load the cache from disk if the file exists
        try:
            with open(self.cache_filepath, 'rb') as f:
                self.token_cache = pickle.load(f)
        except FileNotFoundError:
            self.token_cache = {}

    def tokenize_camel_case(self, name: str) -> list:
        tokens = []
        start = 0
        for i in range(1, len(name)):
            if name[i].isupper():
                tokens.append(name[start:i])
                start = i
        tokens.append(name[start:])
        return tokens

    def get_embedding(self, token: str) -> torch.Tensor:
        if token in self.token_cache:
            return self.token_cache[token]
        nl_tokens = self.tokenizer.tokenize(token)
        token_ids = self.tokenizer.convert_tokens_to_ids(nl_tokens)
        embedding = self.model(torch.tensor(token_ids)[None,:])[1]
        self.token_cache[token] = embedding
        # Save the cache to disk
        with open(self.cache_filepath, 'wb') as f:
            pickle.dump(self.token_cache, f)
        return embedding

    def get_method_name_embedding(self, method_name: str) -> torch.Tensor:
        tokens = self.tokenize_camel_case(method_name)
        embeddings = [self.get_embedding(token) for token in tokens]
        method_embedding = torch.mean(torch.stack(embeddings), dim=0)
        return method_embedding

    def cosine_similarity_method_names(self, method_name1: str, method_name2: str) -> float:
        embed1 = self.get_method_name_embedding(method_name1)
        embed2 = self.get_method_name_embedding(method_name2)
        return cosine_similarity(embed1, embed2).item()


# Usage example
embedder = MethodNameEmbedding()
similarity = embedder.cosine_similarity_method_names("returnMaximumValue", "getValueFromDictionary")
print(similarity)
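A small usage sketch (not part of the original gist) showing the intermediate steps: the camelCase splitter produces the sub-tokens that are embedded individually, cached, and mean-pooled into one vector per method name. The method names below are illustrative only.

embedder = MethodNameEmbedding()

# The splitter breaks a name at every uppercase letter.
print(embedder.tokenize_camel_case("returnMaximumValue"))      # ['return', 'Maximum', 'Value']
print(embedder.tokenize_camel_case("getValueFromDictionary"))  # ['get', 'Value', 'From', 'Dictionary']

# Each sub-token embedding is computed once (then served from the pickle cache),
# the per-token vectors are averaged, and the two averages are compared.
print(embedder.cosine_similarity_method_names("openConnection", "closeConnection"))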