embedding.py
import pickle

import torch
from torch.nn.functional import cosine_similarity
from transformers import AutoTokenizer, AutoModel


class MethodNameEmbedding:
    def __init__(self, cache_filepath="token_cache.pkl"):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.model = AutoModel.from_pretrained("microsoft/codebert-base")
        self.cache_filepath = cache_filepath
        self.new_tokens_count = 0
        # Load the cache from disk if the file exists.
        try:
            with open(self.cache_filepath, 'rb') as f:
                self.token_cache = pickle.load(f)
        except FileNotFoundError:
            self.token_cache = {}

    def tokenize_camel_case(self, name: str) -> list:
        # Split a camelCase identifier at each uppercase letter,
        # e.g. "returnMaximumValue" -> ["return", "Maximum", "Value"].
        tokens = []
        start = 0
        for i in range(1, len(name)):
            if name[i].isupper():
                tokens.append(name[start:i])
                start = i
        tokens.append(name[start:])
        return tokens

    def save_cache(self):
        with open(self.cache_filepath, 'wb') as f:
            pickle.dump(self.token_cache, f)
        self.new_tokens_count = 0

    def get_embedding(self, token: str) -> torch.Tensor:
        if token in self.token_cache:
            return self.token_cache[token]
        # Encode with special tokens (<s> ... </s>) so the pooler,
        # which reads the first position, sees a well-formed sequence.
        inputs = self.tokenizer(token, return_tensors="pt")
        with torch.no_grad():  # inference only; don't build an autograd graph
            embedding = self.model(**inputs).pooler_output
        self.token_cache[token] = embedding
        self.new_tokens_count += 1
        # Persist the cache once 50 new tokens have accumulated.
        if self.new_tokens_count >= 50:
            self.save_cache()
        return embedding

    def get_method_name_embedding(self, method_name: str) -> torch.Tensor:
        # A method name's embedding is the mean of its sub-token embeddings.
        tokens = self.tokenize_camel_case(method_name)
        embeddings = [self.get_embedding(token) for token in tokens]
        method_embedding = torch.mean(torch.stack(embeddings), dim=0)
        return method_embedding

    def cosine_similarity_method_names(self, method_name1: str, method_name2: str) -> float:
        embed1 = self.get_method_name_embedding(method_name1)
        embed2 = self.get_method_name_embedding(method_name2)
        return cosine_similarity(embed1, embed2).item()


# Usage example
embedder = MethodNameEmbedding()
similarity = embedder.cosine_similarity_method_names("returnMaximumValue", "getValueFromDictionary")
print(similarity)
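
Note that save_cache() only fires automatically after 50 new tokens, so embeddings computed near the end of a run would otherwise never reach disk. Below is a minimal sketch of a longer session that flushes the cache explicitly on the way out; the method-name pairs are illustrative inputs, not part of the original gist.

# Hypothetical example pairs, chosen only to illustrate the API.
pairs = [
    ("computeTotalPrice", "calculateOrderCost"),
    ("openFile", "closeConnection"),
]
for name_a, name_b in pairs:
    score = embedder.cosine_similarity_method_names(name_a, name_b)
    print(f"{name_a} vs {name_b}: {score:.4f}")

# Flush whatever accumulated since the last automatic save (fewer
# than 50 new tokens would otherwise be lost when the process exits).
embedder.save_cache()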