-
-
Save eubide/74338b11f161db444b93141c413865d0 to your computer and use it in GitHub Desktop.
Very fast function to get the cosine similarity between 2 short texts, where counting the number of occurrences of each word is not needed (i.e. a binary bag of words), but it works pretty well with non-ASCII and unusual characters.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from unidecode import unidecode | |
| import re | |
| import sys | |
| import inflection | |
| import numpy as np | |
| import math | |
| from collections import defaultdict | |
# Using cosine_similarity, own faster implementation, inspired by
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a
# Memo cache mapping raw text -> token list.  A defaultdict that yields None
# lets cosine_similarity probe with plain indexing instead of .get().
# NOTE(review): unbounded — grows for the process lifetime when cache=True.
_tokens_cache = defaultdict(lambda: None)
# Phone-number normalization pattern: re.sub with it drops every non-digit
# character and a run of zeros anchored at the very start of the string.
# NOTE(review): because ^ anchors to position 0 of the ORIGINAL string,
# zeros after a leading '+' (e.g. '+0034…') are NOT stripped — confirm
# whether that asymmetry is intended.
_phone_regex = re.compile(r'[^\d]|^0+')
def cosine_similarity(text1, text2, cache=False):
    """Return the binary bag-of-words cosine similarity of two short texts.

    Each text is tokenized with get_tokens() and treated as a binary word
    vector (presence/absence, not counts).  The result is a float in
    [0.0, 1.0]; two texts with no words in common score 0.0 and texts
    producing identical token lists score 1.0.

    :param text1: first text (any string get_tokens accepts)
    :param text2: second text
    :param cache: if True, memoize token lists in the module-level
                  _tokens_cache (unbounded) keyed by the raw text
    """
    tok1 = tok2 = None
    if cache:
        # _tokens_cache is a defaultdict returning None on a miss.
        tok1 = _tokens_cache[text1]
        tok2 = _tokens_cache[text2]
    if tok1 is None:
        tok1 = get_tokens(text1)
        if cache:
            _tokens_cache[text1] = tok1
    if tok2 is None:
        tok2 = get_tokens(text2)
        if cache:
            _tokens_cache[text2] = tok2
    # Either text tokenized to nothing -> similarity is undefined; use 0.0.
    if not tok1 or not tok2:
        return 0.0
    # Token lists are sorted, so equality means identical binary vectors.
    if tok1 == tok2:
        return 1.0
    # For binary (0/1) vectors the cosine has a closed form:
    #   dot(v1, v2)        = |s1 & s2|
    #   ||v1|| * ||v2||    = sqrt(|s1|) * sqrt(|s2|)
    # which avoids building the vocabulary vectors entirely and replaces the
    # former O(|vocab| * len(tokens)) list-membership scan with O(n) set ops.
    s1 = set(tok1)
    s2 = set(tok2)
    common = s1 & s2
    if not common:
        # No intersection at all.
        return 0.0
    return len(common) / math.sqrt(len(s1) * len(s2))
| _stopwords = set(('hotel', 'amp')) | |
| # last is a subset from string.punctuation | |
| _nopunctuation = str.maketrans('()[]-&:;./-.', ' ', '\'`´!"#$%*+,<=>?@\\^_`{|}~') | |
def get_tokens(text):
    """Normalize *text* into a sorted list of singularized word tokens.

    Pipeline: strip/space-out punctuation via _nopunctuation, transliterate
    to ASCII with unidecode, lowercase, split on whitespace, drop one-char
    words and _stopwords, singularize each remaining word.  Anything from an
    'ex'/'formerly' marker onward is discarded (common for renamed hotels,
    e.g. "New Name ex Old Name").
    """
    cleaned = unidecode(text.translate(_nopunctuation)).lower()
    tokens = []
    for word in cleaned.split():
        # NOTE(review): the stopword test runs on the raw word, so plural
        # forms like 'hotels' slip past the 'hotel' stopword — confirm.
        if len(word) <= 1 or word in _stopwords:
            continue
        tokens.append(inflection.singularize(word))
    for idx, word in enumerate(tokens):
        # Truncate at the rename marker (checked in original word order,
        # before sorting).
        if word in ('ex', 'formerly'):
            tokens = tokens[:idx]
            break
    tokens.sort()
    return tokens
def phonenumber_equal(a, b):
    """Return True if *a* and *b* represent the same phone number.

    Both strings are normalized by removing every non-digit character and
    then stripping leading zeros (trunk/international prefixes).  Numbers
    with fewer than 9 significant digits never match, to avoid false
    positives on partial numbers.
    """
    # Two-step normalization (digits first, then lstrip) also strips the
    # zeros in inputs like '+0034…', which the single regex r'[^\d]|^0+'
    # missed because its ^ anchor only matched at position 0.
    a = re.sub(r'\D', '', a).lstrip('0')
    b = re.sub(r'\D', '', b).lstrip('0')
    # BUGFIX: the original condition `len(a) > 8 or len(b) > 8 and a == b`
    # parsed as `len(a) > 8 or (len(b) > 8 and a == b)` because `and` binds
    # tighter than `or`, so ANY 9+-digit `a` matched anything.  Require
    # both lengths and equality.
    return len(a) > 8 and len(b) > 8 and a == b
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment