Skip to content

Instantly share code, notes, and snippets.

@Codegass
Last active September 26, 2023 17:04
Show Gist options
  • Save Codegass/74e901428ab7cbe9722985e708349f5d to your computer and use it in GitHub Desktop.
Save Codegass/74e901428ab7cbe9722985e708349f5d to your computer and use it in GitHub Desktop.
embedding.py
import os
import pickle

import torch
from torch.nn.functional import cosine_similarity
from transformers import AutoTokenizer, AutoModel
class MethodNameEmbedding:
    """Embed camelCase method names with CodeBERT and compare them.

    A method name is split into camelCase sub-tokens. Each sub-token is
    embedded independently with ``microsoft/codebert-base``, and the method
    name's embedding is the mean of its sub-token embeddings. Sub-token
    embeddings are memoised in ``self.token_cache`` and persisted to
    ``cache_filepath`` with ``pickle``.

    Args:
        cache_filepath: Path of the pickle file used to persist the
            sub-token embedding cache across runs.
        save_every: Flush the cache to disk after this many newly computed
            sub-token embeddings (default 50).
    """

    def __init__(self, cache_filepath="token_cache.pkl", *, save_every=50):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.model = AutoModel.from_pretrained("microsoft/codebert-base")
        # This class only runs inference; keep the model in evaluation mode.
        self.model.eval()
        self.cache_filepath = cache_filepath
        self.save_every = save_every
        # Number of embeddings computed since the cache was last written.
        self.new_tokens_count = 0
        # Load the cache from disk if the file exists.
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # cache_filepath at files written by this program.
        try:
            with open(self.cache_filepath, 'rb') as f:
                self.token_cache = pickle.load(f)
        except FileNotFoundError:
            self.token_cache = {}

    def tokenize_camel_case(self, name: str) -> list:
        """Split ``name`` before every uppercase character after index 0.

        Examples: ``"returnMaximumValue"`` -> ``["return", "Maximum", "Value"]``.
        Consecutive capitals are split one per token (``"getHTTP"`` ->
        ``["get", "H", "T", "T", "P"]``). Underscores and digits are not
        treated as separators. An empty name yields ``[""]``.
        """
        tokens = []
        start = 0
        for i in range(1, len(name)):
            if name[i].isupper():
                tokens.append(name[start:i])
                start = i
        tokens.append(name[start:])
        return tokens

    def save_cache(self):
        """Persist ``token_cache`` to ``cache_filepath`` and reset the counter.

        The pickle is written to a sibling ``.tmp`` file and then moved over
        the real path with ``os.replace``. An interrupted write can then
        never leave a truncated cache that would make ``__init__`` fail on
        the next run.
        """
        tmp_path = f"{self.cache_filepath}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.token_cache, f)
        os.replace(tmp_path, self.cache_filepath)
        self.new_tokens_count = 0

    def get_embedding(self, token: str) -> torch.Tensor:
        """Return the CodeBERT embedding of a single sub-token (memoised).

        A newly computed embedding is stored in ``token_cache``. The cache is
        flushed to disk once ``save_every`` new embeddings have accumulated.

        Raises:
            ValueError: if the tokenizer yields no sub-words for ``token``
                (e.g. the empty string); the model cannot embed an empty
                sequence.
        """
        if token in self.token_cache:
            return self.token_cache[token]
        nl_tokens = self.tokenizer.tokenize(token)
        if not nl_tokens:
            raise ValueError(f"cannot embed {token!r}: tokenizer produced no sub-words")
        token_ids = self.tokenizer.convert_tokens_to_ids(nl_tokens)
        input_ids = torch.tensor(token_ids, dtype=torch.long)[None, :]
        # NOTE(review): no <s>/</s> special tokens are added, so output [1]
        # (presumably the pooler output, derived from the first position) is
        # based on the token's first sub-word. Confirm this is intended.
        # no_grad: nothing here is trained, and without it every cached
        # (and pickled) embedding would keep its autograd graph alive.
        with torch.no_grad():
            embedding = self.model(input_ids)[1]
        self.token_cache[token] = embedding
        self.new_tokens_count += 1
        # Flush the cache once enough new tokens have accumulated.
        if self.new_tokens_count >= self.save_every:
            self.save_cache()
        return embedding

    def get_method_name_embedding(self, method_name: str) -> torch.Tensor:
        """Return the mean of the sub-token embeddings of ``method_name``.

        The result has the same shape as a single sub-token embedding.
        """
        tokens = self.tokenize_camel_case(method_name)
        embeddings = [self.get_embedding(token) for token in tokens]
        method_embedding = torch.mean(torch.stack(embeddings), dim=0)
        return method_embedding

    def cosine_similarity_method_names(self, method_name1: str, method_name2: str) -> float:
        """Return the cosine similarity of two method-name embeddings.

        Assumes each embedding holds a single row (batch of 1), so the
        similarity along dim 1 is a one-element tensor that ``.item()``
        can convert.
        """
        embed1 = self.get_method_name_embedding(method_name1)
        embed2 = self.get_method_name_embedding(method_name2)
        return cosine_similarity(embed1, embed2).item()
# Usage example.
embedder = MethodNameEmbedding()
similarity = embedder.cosine_similarity_method_names("returnMaximumValue", "getValueFromDictionary")
print(similarity)
# The cache is only flushed automatically after a batch of new sub-tokens
# (50 by default), so persist whatever this short run computed before exit.
embedder.save_cache()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment