@koshian2 · Created January 25, 2019 13:43
Train with 1000 arccos triplet loss
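Metric learning on CIFAR-10 with only 1,000 labeled training images (the "train with 1000" setting): a VGG-like network learns embeddings with a triplet loss on arccos (angular) distance, and each epoch is evaluated with 1-nearest-neighbor and threshold-based classification. Written for the TF 1.x Colab TPU runtime.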
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, History, LearningRateScheduler
import tensorflow.keras.backend as K
from tensorflow.contrib.tpu.python.tpu import keras_support
from train1000 import cifar10
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import os, json, math
# VGG-like model
def create_siamese(latent_dims):
    inputs = layers.Input((32, 32, 3))
    x = inputs
    # 4 stages of 3 conv-BN-ReLU blocks: 64 -> 128 -> 256 -> 512 channels
    for i in range(4):
        for j in range(3):
            x = layers.Conv2D(64 * (2 ** i), 3, padding="same")(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation("relu")(x)
        if i != 3:
            x = layers.AveragePooling2D(2)(x)
    x = layers.GlobalAveragePooling2D()(x)
    # project down only when the requested embedding size differs from 512
    if latent_dims != 512:
        x = layers.Dense(latent_dims)(x)
    return Model(inputs, x)
MARGIN = 0.5  # triplet margin, in radians (distances are arccos values in [0, pi])

def upper_triangle(matrix):
    # keep only the strictly upper triangle (diagonal zeroed out)
    upper = tf.matrix_band_part(matrix, 0, -1)
    diagonal = tf.matrix_band_part(matrix, 0, 0)
    diagonal_mask = tf.sign(tf.abs(diagonal))
    return upper * (1.0 - diagonal_mask)
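# e.g. upper_triangle on a 3x3 all-ones matrix yields
#   [[0, 1, 1],
#    [0, 0, 1],
#    [0, 0, 0]]
# so each unordered pair (i, j), i < j, is counted exactly once.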
# Idea of arccos distance is from
# ArcFace: Additive Angular Margin Loss for Deep Face Recognition
# https://arxiv.org/pdf/1801.07698.pdf
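# The loss below is a batch-all triplet hinge on angular distance:
#   d(x, y) = arccos(cos_sim(x, y)),  d in [0, pi]
#   loss = sum over (positive pair p, negative pair n) of max(d_p - d_n + MARGIN, 0)
# i.e. every same-class pair in the batch is pushed closer than every
# different-class pair by at least MARGIN radians.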
def arc_triplet_loss(label, embeddings):
    # cosine similarity matrix between all pairs in the batch
    norm_emb = tf.nn.l2_normalize(embeddings, axis=-1)
    x1 = tf.expand_dims(norm_emb, axis=0)
    x2 = tf.expand_dims(norm_emb, axis=1)
    cosine_sim = tf.reduce_sum(x1 * x2, axis=-1)
    # arccos distance; clip slightly inside [-1, 1] because acos returns NaN
    # at the exact endpoints in floating point
    eps = K.epsilon()
    arccos_dist = tf.acos(tf.clip_by_value(cosine_sim, -1.0 + eps, 1.0 - eps))
    # label-equality matrix (label has shape [None, latent_dims]; the class id
    # is stored in column 0, the remaining columns are dummies)
    lb1 = tf.expand_dims(label[:, 0], axis=0)
    lb2 = tf.expand_dims(label[:, 0], axis=1)
    equal_mat = K.cast(tf.equal(lb1, lb2), "float32")
    # positives: tf.where cannot be used here, so enumerate all pairs instead
    positive_flag = upper_triangle(equal_mat)
    # fill the non-positive entries with -pi
    positive_dist = positive_flag * arccos_dist + (1.0 - positive_flag) * (-math.pi)
    positive_dist = tf.reshape(positive_dist, [-1, 1])
    # negatives: fill the non-negative entries with +pi
    negative_flag = upper_triangle(1.0 - equal_mat)
    negative_dist = negative_flag * arccos_dist + (1.0 - negative_flag) * math.pi
    negative_dist = tf.reshape(negative_dist, [1, -1])
    # triplet loss over all (positive pair, negative pair) combinations
    loss = tf.maximum(positive_dist - negative_dist + MARGIN, 0.0)
    return tf.reduce_sum(loss)
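# Note on the fill values: positive_dist is reshaped to [P, 1] and
# negative_dist to [1, N], so the subtraction broadcasts to a [P, N] grid of
# every (positive pair, negative pair) combination. Dummy positives (-pi)
# always give a negative hinge argument and vanish under tf.maximum; dummy
# negatives (+pi) vanish as long as the positive distance stays below
# pi - MARGIN, which holds for all but near-antipodal positive pairs.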
class EmbeddingCallback(Callback):
    def __init__(self, siamese_model, X_train, y_train, X_test, y_test):
        self.model = siamese_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.test_nearest_neighbor_acc = []
        self.test_threshold_simple = []
        self.test_threshold_weighted = []

    # Evaluate a bit more carefully using a threshold; first build the
    # arccos distance matrix from every target embedding to every anchor
    def pairwise_distance_matrix(self, anchor_embedding, target_embedding):
        similarities = np.zeros((target_embedding.shape[0], anchor_embedding.shape[0]), dtype=np.float32)
        for i in tqdm(range(similarities.shape[0])):
            sim = cosine_similarity(target_embedding[i, :].reshape(1, -1), anchor_embedding)[0]
            similarities[i, :] = sim
        eps = 1e-10
        similarities = np.clip(similarities, -1.0 + eps, 1.0 - eps)
        return np.arccos(similarities)
    # VAL rate: fraction of truly-same pairs accepted at the given threshold
    def true_accept(self, distance, labels, threshold):
        # count each pair once via the strictly upper triangular mask
        upper_mask = np.triu(np.ones(distance.shape, dtype=bool), k=1)
        # pairs that are truly the same class
        truth_same = np.expand_dims(labels, axis=1) == np.expand_dims(labels, axis=0)
        # pairs predicted to be the same (distance within threshold)
        pred_same = distance <= threshold
        # true accepts
        ta = np.logical_and(pred_same, truth_same)
        calc_true_same = np.logical_and(upper_mask, truth_same)
        calc_ta = np.logical_and(upper_mask, ta)
        # VAL rate
        return np.sum(calc_ta) / np.sum(calc_true_same)
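    # (The VAL rate above appears to follow the "validation rate" metric from
    # FaceNet, Schroff et al. 2015: among all genuinely-same pairs, the
    # fraction whose distance falls within the threshold.)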
    def find_threshold(self, distance_matrix, onehots, nof_fold=10):
        # (nof_fold is currently unused; the threshold is chosen on the full matrix)
        assert distance_matrix.shape[0] == distance_matrix.shape[1]
        n = distance_matrix.shape[0]
        assert onehots.shape[0] == n
        # decode one-hot labels to integer class ids
        labels = np.sum(np.arange(onehots.shape[1]) * onehots, axis=1).astype(np.int32)
        # grid-search the threshold that maximizes the VAL rate
        thresholds = np.arange(0.0, 4.0, 0.001, dtype=np.float32)
        rate = np.zeros(thresholds.shape, dtype=np.float32)
        for i, th in enumerate(thresholds):
            rate[i] = self.true_accept(distance_matrix, labels, th)
        print("val_rate", rate)
        best_idx = np.argmax(rate)
        print(f"Best threshold : {thresholds[best_idx]}, Best VAL : {rate[best_idx]:.04}")
        return thresholds[best_idx]
    # Prediction by 1-nearest neighbor
    def one_nearest_neighbor(self, distance_matrix, anchor_onehots, target_onehots):
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        indices = np.argsort(distance_matrix, axis=-1)
        # a square matrix means anchors == targets (train set), so skip the
        # nearest neighbor (the sample itself) and use the second nearest
        if distance_matrix.shape[0] == distance_matrix.shape[1]:
            print("This is train set")
            index = indices[:, 1]
        else:
            index = indices[:, 0]
        anchor_label = np.sum(np.arange(anchor_onehots.shape[1]).reshape(1, -1) * anchor_onehots, axis=-1)
        y_pred = anchor_label[index]
        y_true = np.sum(np.arange(target_onehots.shape[1]) * target_onehots, axis=-1)
        return accuracy_score(y_true, y_pred)
    # Prediction by thresholding: vote among all anchors within the threshold
    def thresholding_pred(self, distance_matrix, threshold, anchor_onehots, target_onehots, use_weighted):
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        # margin below the threshold; 0 for anchors beyond the threshold
        thresholded_distance = np.maximum(threshold - distance_matrix, 0.0)
        if not use_weighted:
            # unweighted vote: every in-threshold anchor counts equally
            thresholded_distance = np.sign(thresholded_distance)
        y_pred = np.zeros(target_onehots.shape[0])
        print("thresholded num : ", np.mean(np.sum(np.logical_not(np.isclose(thresholded_distance, 0.0)), axis=-1)))
        for i in range(y_pred.shape[0]):
            # per-class vote score, optionally weighted by (threshold - distance)
            score = thresholded_distance[i, :].reshape(-1, 1) * anchor_onehots
            pred_index = np.argmax(np.sum(score, axis=0))
            y_pred[i] = pred_index
        y_true = np.sum(np.arange(target_onehots.shape[1]) * target_onehots, axis=-1)
        print(np.bincount(y_pred.astype(np.int32)))
        return accuracy_score(y_true, y_pred)
    def on_epoch_end(self, epoch, logs):
        train_embedding = self.model.predict(self.X_train)
        test_embedding = self.model.predict(self.X_test)
        # distance matrices (train vs train, train vs test)
        distance_train = self.pairwise_distance_matrix(train_embedding, train_embedding)
        distance_test = self.pairwise_distance_matrix(train_embedding, test_embedding)
        # threshold selected on the training set
        print("")
        threshold = self.find_threshold(distance_train, self.y_train)
        # 1-nearest-neighbor evaluation
        print("Simple 1-Nearest Neighbor")
        train_acc = self.one_nearest_neighbor(distance_train, self.y_train, self.y_train)
        test_acc = self.one_nearest_neighbor(distance_test, self.y_train, self.y_test)
        self.test_nearest_neighbor_acc.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_nearest_neighbor_acc):.04}")
        # threshold-based evaluation
        print("Simple thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, False)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, False)
        self.test_threshold_simple.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_threshold_simple):.04}")
        print("Weighted thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, True)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, True)
        self.test_threshold_weighted.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_threshold_weighted):.04}")
def data_augmentation(image):
    outputs = np.zeros(image.shape, dtype=np.float32)
    # random crop: keep a random 28x28 window, zero out the rest
    crop_x = np.random.randint(0, 4)
    crop_y = np.random.randint(0, 4)
    outputs[crop_x:crop_x+28, crop_y:crop_y+28, :] = image[crop_x:crop_x+28, crop_y:crop_y+28, :]
    # random horizontal flip
    if np.random.rand() >= 0.5:
        outputs = outputs[:, ::-1, :]
    return outputs
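# Note: because the same slice indices are used on both sides of the
# assignment, the "crop" keeps the image in place and zeroes everything
# outside a random 28x28 window (border masking), rather than translating
# the crop as a conventional padded random crop would.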
def generator(X, y, batch_size, use_augmentation, n_latent_dims):
    while True:
        X_cache, y_cache = [], []
        indices = np.random.permutation(X.shape[0])
        for i in indices:
            if use_augmentation:
                X_cache.append(data_augmentation(X[i]))
            else:
                X_cache.append(X[i])
            # put the integer class label in column 0; the remaining columns are dummies
            y_item = np.zeros(n_latent_dims)
            y_item[0] = np.sum(np.arange(y.shape[1]) * y[i])
            y_cache.append(y_item)
            if len(y_cache) == batch_size:  # X is already divided by 255, so no scaling here
                X_batch = np.asarray(X_cache, np.float32)
                y_batch = np.asarray(y_cache, np.float32)
                X_cache, y_cache = [], []
                yield X_batch, y_batch
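# (Presumably the label vector is padded out to n_latent_dims because the
# target array handed to the loss must be shape-compatible with the model's
# [batch, latent_dims] output; arc_triplet_loss reads only column 0.)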
def step_decay(epoch):
    # 1e-3 for epochs 0-11, 2e-4 for 12-18, 4e-5 from epoch 19
    x = 1e-3
    if epoch >= 12: x = 2e-4
    if epoch >= 19: x = 4e-5
    return x
def train(n_dims, use_aug):
    (X_train, y_train), (X_test, y_test) = cifar10()
    siamese = create_siamese(n_dims)
    siamese.compile("adam", arc_triplet_loss)
    # convert the Keras model to a TPU model (TF 1.x contrib API)
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    siamese = tf.contrib.tpu.keras_to_tpu_model(siamese, strategy=strategy)

    embed_cb = EmbeddingCallback(siamese, X_train, y_train, X_test, y_test)
    hist = History()
    scheduler = LearningRateScheduler(step_decay)
    batch_size = 200
    # 400 passes over the 1,000-sample training set per epoch
    siamese.fit_generator(generator(X_train, y_train, batch_size, use_aug, n_dims),
                          steps_per_epoch=X_train.shape[0]*400//batch_size,
                          callbacks=[embed_cb, hist, scheduler],
                          max_queue_size=1, epochs=25)
if __name__ == "__main__":
    K.clear_session()
    print(512, "starts")
    train(512, False)
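# Usage note: this script assumes the TF 1.x Colab TPU runtime (COLAB_TPU_ADDR
# is set by Colab); tf.contrib and keras_to_tpu_model were removed in TF 2.x.
# The external train1000 module (not included here) is expected to return
# CIFAR-10 as ((X_train, y_train), (X_test, y_test)) with 1,000 training
# samples, one-hot labels, and pixel values already scaled to [0, 1].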