Created
September 24, 2025 16:08
-
-
Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.
Revisions
-
tomaarsen created this gist
Sep 24, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,116 @@ import re from huggingface_hub import get_collection, ModelCard from sentence_transformers import SentenceTransformer from sentence_transformers.models import Normalize collection = get_collection(collection_slug="clips/e5-nl-68be9d3760240ce5c7d9f831") ST_SNIPPET_PATTERN = r"""\ from sentence_transformers import SentenceTransformer model = SentenceTransformer\((?:'|")([a-zA-Z0-9_\/\.-]+?)(?:'|")\) input_texts = \[ (?:'|")(.+?)(?:'|"), (?:'|")(.+?)(?:'|"), (?:'|")(.+?)(?:'|"), (?:'|")(.+?)(?:'|") \] embeddings = model.encode\(input_texts, normalize_embeddings=True\) """ ST_SNIPPET_TEMPLATE = """ from sentence_transformers import SentenceTransformer # Load the model from Hugging Face model = SentenceTransformer("{model_name}") # Perform inference using encode_query/encode_document for retrieval, # or encode_query for general purpose embeddings. Prompt prefixes # are automatically added with these two methods. queries = [ {query1}, {query2}, ] documents = [ {document1}, {document2}, ] query_embeddings = model.encode_query(queries) document_embeddings = model.encode_document(documents) print(query_embeddings.shape, document_embeddings.shape) {shapes} similarities = model.similarity(query_embeddings, document_embeddings) {similarities} """ FINISHED_MODELS = [] for item in collection.items: if item.item_type != "model": continue model_id = item.item_id if model_id in FINISHED_MODELS: continue model = SentenceTransformer( model_id, prompts={ "query": "query: ", "document": "passage: ", }, ) model.add_module("2", Normalize()) model_card = ModelCard.load(model_id) model_card.data.library_name = "sentence-transformers" model_card.data.language = "nl" tags = model_card.data.tags or [] if "transformers" not in tags: tags.append("transformers") model_card.data.tags = tags content = model_card.content match = re.search(ST_SNIPPET_PATTERN, content) if match: model_name = match.group(1) queries = [match.group(2), match.group(3)] documents = [match.group(4), match.group(5)] if not queries[0].startswith("query: "): print("Unexpected query format in model card for", model_id) breakpoint() queries = [query.split("query: ")[-1].strip() for query in queries] documents = [doc.split("passage: ")[-1].strip() for doc in documents] query_embeddings = model.encode_query(queries, normalize_embeddings=True) doc_embeddings = model.encode_document(documents, normalize_embeddings=True) shapes = f"# {query_embeddings.shape} {doc_embeddings.shape}" similarities = model.similarity(query_embeddings, doc_embeddings) similarities = "# " + str(similarities).replace("\n", "\n# ") content = content[:match.start()] + ST_SNIPPET_TEMPLATE.format( model_name=model_id, query1=repr(queries[0]), query2=repr(queries[1]), document1=repr(documents[0]), document2=repr(documents[1]), shapes=shapes, similarities=similarities, ) + content[match.end():] else: print("No match found in model card for", model_id) breakpoint() model_card.content = content model_card.validate("model") model._model_card_text = str(model_card) # model.push_to_hub(model_id.replace("clips/", "tomaarsen/"), private=True) url = model.push_to_hub(model_id, create_pr=True) print("Pushed", model_id, "->", url) breakpoint()