@varrek
Last active January 15, 2022 20:16
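import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# df: one row per category ('cat_name', 'cat_embeding'); df_test: one row per sample ('embed')
# Cosine similarity of every test embedding (rows) against every category embedding (columns)
cos_sim = cosine_similarity(df_test['embed'].tolist(), df['cat_embeding'].tolist())
# For each test row, the position of the most similar category
indexes = np.argmax(cos_sim, axis=1)
# argmax returns positions, so select with .iloc rather than label-based .loc
df_test['prediction'] = df['cat_name'].iloc[indexes].tolist()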
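The snippet classifies each test row by nearest category under cosine similarity, cos(a, b) = a·b / (‖a‖‖b‖). The sketch below is a self-contained illustration of the same idea, assuming only NumPy and pandas: instead of calling sklearn's `cosine_similarity`, it L2-normalises the rows so that a plain matrix product yields the same similarity matrix. The helper name `predict_nearest_category` and the toy 2-D embeddings are hypothetical, chosen only for demonstration; it also assumes no embedding is an all-zero vector (sklearn's normalisation handles that case, this sketch does not).

```python
import numpy as np
import pandas as pd


def predict_nearest_category(test_embeds, cat_embeds, cat_names):
    """Hypothetical helper: give each test embedding the name of its most cosine-similar category."""
    a = np.asarray(test_embeds, dtype=float)  # shape (n_test, d)
    b = np.asarray(cat_embeds, dtype=float)   # shape (n_cats, d)
    # L2-normalise every row so that a dot product equals cosine similarity
    # (assumes no all-zero rows, which would divide by zero)
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    sim = a @ b.T                              # shape (n_test, n_cats)
    best = np.argmax(sim, axis=1)              # position of the best category; ties go to the first
    return [cat_names[i] for i in best]


# Toy 2-D "embeddings" (made up); real ones would come from an embedding model.
df = pd.DataFrame({
    "cat_name": ["sports", "politics"],
    "cat_embeding": [np.array([1.0, 0.0]), np.array([0.0, 1.0])],
})
df_test = pd.DataFrame({
    "embed": [np.array([0.9, 0.1]), np.array([0.2, 0.8]), np.array([3.0, 2.0])],
})

df_test["prediction"] = predict_nearest_category(
    df_test["embed"].tolist(), df["cat_embeding"].tolist(), df["cat_name"].tolist()
)
print(df_test["prediction"].tolist())  # ['sports', 'politics', 'sports']
```

Because cosine similarity ignores vector length, the third row `[3.0, 2.0]` is assigned to "sports" purely by direction, not magnitude.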