"""Numerical parity checks between an MLX (mlx_lm) model and a Hugging Face
PyTorch model that both produce last-token-pooled, L2-normalised embeddings.

Two checks run from ``__main__``:

* ``compare_mlx_hf_embeddings`` compares the normalised sentence embeddings.
* ``test_embedding_similarity`` compares query/document similarity matrices.

When run as a script the process exits with status 0 only if both checks pass.
"""

import sys
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx
import numpy as np
import torch
import torch.nn.functional as F
from mlx_lm.utils import load
from torch import Tensor
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer


def tokenize_texts(
    tokenizer: PreTrainedTokenizer, sentences: List[str], max_length: int
) -> Dict[str, mx.array]:
    """Tokenize ``sentences`` into MLX arrays.

    Args:
        tokenizer: Hugging Face tokenizer; must support ``return_tensors="mlx"``.
        sentences: Texts to encode.
        max_length: Truncation length. Padding is to the longest sentence in
            the batch (``padding=True``), not to ``max_length``.

    Returns:
        The tokenizer's output, containing at least ``"input_ids"`` and
        ``"attention_mask"``. For an empty ``sentences`` list, both are int32
        zero arrays of shape ``(0, max_length)``.
    """
    if not sentences:
        # Short-circuit the empty batch with correctly-typed empty arrays.
        return {
            "input_ids": mx.zeros((0, max_length), dtype=mx.int32),
            "attention_mask": mx.zeros((0, max_length), dtype=mx.int32),
        }
    return tokenizer(
        sentences,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="mlx",
    )


def encode_batch(model: Any, batch_mx: Dict[str, mx.array]) -> mx.array:
    """Run the MLX backbone ``model.model`` and return its output hidden states.

    Only ``input_ids`` is forwarded; the attention mask is NOT passed, so the
    model also processes padding tokens. Assuming the backbone is causal (the
    usual mlx_lm language-model setup -- NOTE(review): confirm), padding placed
    *after* the real tokens cannot influence their hidden states. That is why
    the script below forces ``tokenizer.padding_side = "right"``.
    """
    return model.model(batch_mx["input_ids"])


def pool_last_token_simple(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
    """Select each row's hidden state at its last attended position.

    The index is the largest position whose mask entry is non-zero. For right
    padding this equals ``mask.sum() - 1``; for left padding it is the final
    position. Rows with no attended token fall back to position 0.

    Args:
        last_hidden_state: Hidden states indexed as ``[batch, position, ...]``.
        attention_mask: ``(batch, seq_len)`` mask with non-zero entries for
            real tokens.

    Returns:
        One pooled vector per batch row.
    """
    batch_size = last_hidden_state.shape[0]
    seq_len = attention_mask.shape[1]
    positions = mx.arange(seq_len)
    # Masked-out positions become -1 so they never win the max.
    last_indices = mx.max(mx.where(attention_mask > 0, positions, -1), axis=1)
    last_indices = mx.maximum(last_indices, 0)
    return last_hidden_state[mx.arange(batch_size), last_indices]


def normalize_embeddings(embeddings: mx.array, eps: float = 1e-9) -> mx.array:
    """L2-normalise ``embeddings`` along the last axis.

    The norm is clamped below by ``eps`` (the same max-clamp scheme as
    ``torch.nn.functional.normalize``) to avoid division by zero.
    """
    norm = mx.linalg.norm(embeddings, ord=2, axis=-1, keepdims=True)
    return embeddings / mx.maximum(norm, eps)


def hf_last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """PyTorch twin of :func:`pool_last_token_simple`.

    Picks, for each row, the hidden state at the largest attended position,
    which handles both left and right padding. Rows with an all-zero mask
    fall back to position 0, matching the MLX version, instead of wrapping
    to the last (padding) position through a ``-1`` index.
    """
    batch_size = last_hidden_states.shape[0]
    seq_len = attention_mask.shape[1]
    positions = torch.arange(seq_len, device=attention_mask.device)
    masked_positions = torch.where(attention_mask.bool(), positions, -1)
    last_indices = masked_positions.amax(dim=1).clamp(min=0)
    return last_hidden_states[
        torch.arange(batch_size, device=last_hidden_states.device), last_indices
    ]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query with its task instruction prefix."""
    return f'Instruct: {task_description}\nQuery: {query}'


def calculate_similarity_scores(
    query_embeddings: np.ndarray, doc_embeddings: np.ndarray
) -> np.ndarray:
    """Return the ``(n_queries, n_docs)`` matrix of dot products scaled by 100."""
    return (query_embeddings @ doc_embeddings.T) * 100


def _mlx_embed(
    mlx_model: Any,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    max_length: int,
    *,
    verbose: bool = False,
) -> np.ndarray:
    """Tokenize, encode, pool and normalise ``texts`` with MLX; return a NumPy copy."""
    batch_mx = tokenize_texts(tokenizer, texts, max_length)
    if verbose:
        print(f"MLX Tokenized input_ids shape: {batch_mx['input_ids'].shape}")
    # Compare in float32 regardless of the dtype the MLX weights were loaded in.
    hidden = encode_batch(mlx_model, batch_mx).astype(mx.float32)
    if verbose:
        print(f"MLX Hidden state shape: {hidden.shape}")
    pooled = pool_last_token_simple(hidden, batch_mx["attention_mask"])
    if verbose:
        print(f"MLX Pooled shape: {pooled.shape}")
    normalized = normalize_embeddings(pooled)
    if verbose:
        print(f"MLX Normalized shape: {normalized.shape}")
        print(f"MLX Normalized dtype: {normalized.dtype}")
    mx.eval(normalized)  # materialise the lazy graph before copying out
    return np.array(normalized, copy=True)


def _hf_embed(
    hf_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    max_length: int,
    device: str,
    *,
    verbose: bool = False,
) -> np.ndarray:
    """Reference pipeline: embed ``texts`` with the HF PyTorch model on ``device``."""
    hf_model.to(device)
    hf_model.eval()
    batch_pt = tokenizer(
        texts, max_length=max_length, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    if verbose:
        print(f"HF Tokenized input_ids shape: {batch_pt['input_ids'].shape}")
    with torch.no_grad():
        outputs = hf_model(**batch_pt)
    hidden = outputs.last_hidden_state.to(torch.float32)
    if verbose:
        print(f"HF Hidden state shape: {hidden.shape}")
    pooled = hf_last_token_pool(hidden, batch_pt["attention_mask"])
    if verbose:
        print(f"HF Pooled shape: {pooled.shape}")
    normalized = F.normalize(pooled, p=2, dim=1)
    if verbose:
        print(f"HF Normalized shape: {normalized.shape}")
        print(f"HF Normalized dtype: {normalized.dtype}")
    return normalized.cpu().numpy()


def compare_mlx_hf_embeddings(
    mlx_model: Any,
    hf_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    sentences: List[str],
    max_length: int,
    rtol: float = 1e-4,
    atol: float = 1e-5,
    device: str = 'cpu',
) -> bool:
    """Check that MLX and HF produce the same normalised sentence embeddings.

    Both pipelines use the same tokenizer settings, last-token pooling and L2
    normalisation. Errors in either pipeline are printed and reported as a
    failure (``False``) rather than raised.

    Returns:
        ``True`` iff the shapes match and ``np.allclose(mlx, hf, rtol, atol)`` holds.
    """
    print("\n--- Starting Comparison ---")
    print(f"Using sentences: {sentences}")
    print(f"Max length: {max_length}")
    print(f"PyTorch Device: {device}")
    print(f"Tolerances: rtol={rtol}, atol={atol}")

    print("\nRunning MLX Implementation...")
    try:
        mlx_result_np = _mlx_embed(mlx_model, tokenizer, sentences, max_length, verbose=True)
    except Exception as e:  # harness boundary: report, don't abort the run
        print(f"Error during MLX execution: {e}")
        return False
    print("MLX Implementation finished.")
    print(f"MLX Result NumPy shape: {mlx_result_np.shape}, dtype: {mlx_result_np.dtype}")

    print("\nRunning PyTorch (Hugging Face) Reference...")
    try:
        hf_result_np = _hf_embed(
            hf_model, tokenizer, sentences, max_length, device, verbose=True
        )
    except Exception as e:  # harness boundary: report, don't abort the run
        print(f"Error during PyTorch execution: {e}")
        return False
    print("PyTorch (HF) Implementation finished.")

    print("\nComparing Results...")
    if mlx_result_np.shape != hf_result_np.shape:
        print("❌ FAILED: Shape mismatch!")
        print(f" MLX Shape: {mlx_result_np.shape}")
        print(f" HF Shape: {hf_result_np.shape}")
        passed = False
    else:
        print(f"✅ Shapes Match: {mlx_result_np.shape}")
        if np.allclose(mlx_result_np, hf_result_np, rtol=rtol, atol=atol):
            print(f"✅ PASSED: Numerical values are close within tolerance (rtol={rtol}, atol={atol}).")
            passed = True
        else:
            print("❌ FAILED: Numerical values differ significantly!")
            diff = np.abs(mlx_result_np - hf_result_np)
            print(f" Max absolute difference: {np.max(diff)}")
            print(f" Mean absolute difference: {np.mean(diff)}")
            passed = False

    print("\n--- Comparison Finished ---")
    return passed


def test_embedding_similarity(
    mlx_model: Any,
    hf_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    queries: List[str],
    documents: List[str],
    max_length: int,
    rtol: float = 1e-4,
    atol: float = 1e-5,
    device: str = 'cpu',
) -> Tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
    """Check that MLX and HF yield the same query/document similarity matrix.

    Queries and documents are embedded in a single batch, and then split.
    Scores come from :func:`calculate_similarity_scores` (dot product x 100),
    so ``atol`` applies on that scaled range.

    Returns:
        ``(passed, mlx_scores, hf_scores)``. If either pipeline raises, the
        error is printed and ``(False, None, None)`` is returned.
    """
    print("\n--- Starting Similarity Comparison ---")
    print(f"Queries: {len(queries)}")
    print(f"Documents: {len(documents)}")
    print(f"Max length: {max_length}")
    print(f"PyTorch Device: {device}")

    input_texts = list(queries) + list(documents)
    n_queries = len(queries)

    print("\nRunning MLX Implementation...")
    try:
        mlx_embeddings = _mlx_embed(mlx_model, tokenizer, input_texts, max_length)
        mlx_scores = calculate_similarity_scores(
            mlx_embeddings[:n_queries], mlx_embeddings[n_queries:]
        )
    except Exception as e:  # harness boundary: report, don't abort the run
        print(f"Error during MLX similarity calculation: {e}")
        return False, None, None
    print("MLX scores shape:", mlx_scores.shape)
    print("MLX similarity scores:")
    print(mlx_scores)

    print("\nRunning PyTorch Implementation...")
    try:
        hf_embeddings = _hf_embed(hf_model, tokenizer, input_texts, max_length, device)
        hf_scores = calculate_similarity_scores(
            hf_embeddings[:n_queries], hf_embeddings[n_queries:]
        )
    except Exception as e:  # harness boundary: report, don't abort the run
        print(f"Error during PyTorch similarity calculation: {e}")
        return False, None, None
    print("HF scores shape:", hf_scores.shape)
    print("HF similarity scores:")
    print(hf_scores)

    print("\nComparing Similarity Scores...")
    if mlx_scores.shape != hf_scores.shape:
        print("❌ FAILED: Shape mismatch in similarity scores!")
        print(f" MLX Shape: {mlx_scores.shape}")
        print(f" HF Shape: {hf_scores.shape}")
        passed = False
    else:
        print(f"✅ Similarity score shapes match: {mlx_scores.shape}")
        if np.allclose(mlx_scores, hf_scores, rtol=rtol, atol=atol):
            print(f"✅ PASSED: Similarity scores are close within tolerance (rtol={rtol}, atol={atol}).")
            passed = True
        else:
            print("❌ FAILED: Similarity scores differ significantly!")
            diff = np.abs(mlx_scores - hf_scores)
            print(f" Max absolute difference: {np.max(diff)}")
            print(f" Mean absolute difference: {np.mean(diff)}")
            print("\nDetailed score comparison (MLX vs HF):")
            for i, j in np.ndindex(mlx_scores.shape):
                print(
                    f"Query {i+1} - Doc {j+1}: {mlx_scores[i, j]:.2f} vs {hf_scores[i, j]:.2f} "
                    f"(diff: {diff[i, j]:.4f})"
                )
            passed = False

    print("\n--- Similarity Comparison Finished ---")
    return passed, mlx_scores, hf_scores


def main() -> int:
    """Load both models, run both parity checks, and return a process exit status."""
    model_name = "Alibaba-NLP/gte-Qwen2-7B-instruct"
    max_len_test = 128
    # Use the Apple-silicon GPU when present; fall back to CPU elsewhere.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print(f"Loading mlx model '{model_name}' ...")
    try:
        mlx_model, _ = load(model_name)
    except Exception as e:
        print(f"Failed to load mlx model '{model_name}': {e}")
        return 1

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        print(f"Failed to load tokenizer '{model_name}': {e}")
        return 1
    # encode_batch() does not forward the attention mask to the MLX model, so
    # padding must follow the real tokens for both backends to agree.
    tokenizer.padding_side = "right"

    print(f"Loading HF PyTorch Model '{model_name}'...")
    try:
        hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float32)
        print(hf_model)
    except Exception as e:
        print(f"Failed to load HF PyTorch model '{model_name}': {e}")
        return 1

    test_sentences = [
        "This is a test sentence.",
        "Let's compare MLX and PyTorch.",
        "Short one.",
        "A significantly longer sentence to test padding and truncation mechanisms effectively.",
    ]
    test_passed = compare_mlx_hf_embeddings(
        mlx_model=mlx_model,
        hf_model=hf_model,
        tokenizer=tokenizer,
        sentences=test_sentences,
        max_length=max_len_test,
        device=device,
        rtol=1e-5,
        atol=1e-5,
    )
    print(f"\nEmbedding Test Result: {'PASSED' if test_passed else 'FAILED'}")

    task = 'Given a web search query, retrieve relevant passages that answer the query'
    queries = [
        get_detailed_instruct(task, 'how much protein should a female eat'),
        get_detailed_instruct(task, 'summit define'),
    ]
    documents = [
        "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
        "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
    ]
    similarity_passed, _mlx_scores, _hf_scores = test_embedding_similarity(
        mlx_model=mlx_model,
        hf_model=hf_model,
        tokenizer=tokenizer,
        queries=queries,
        documents=documents,
        max_length=max_len_test,
        device=device,
        rtol=1e-3,
        atol=1e-3,
    )
    print(f"\nSimilarity Test Result: {'PASSED' if similarity_passed else 'FAILED'}")

    overall = test_passed and similarity_passed
    print(f"\nOverall Test Result: {'PASSED' if overall else 'FAILED'}")
    return 0 if overall else 1


if __name__ == '__main__':
    sys.exit(main())