Skip to content

Instantly share code, notes, and snippets.

@tomaarsen
Created September 24, 2025 16:08
Show Gist options
  • Select an option

  • Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.

Select an option

Save tomaarsen/f02b628162b8d49a9f93d40758af6ef3 to your computer and use it in GitHub Desktop.

Revisions

  1. tomaarsen created this gist Sep 24, 2025.
    116 changes: 116 additions & 0 deletions update_e5_nl.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,116 @@
    import re
    from huggingface_hub import get_collection, ModelCard
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Normalize

    # Fetch the Hugging Face Hub collection whose items are processed by the
    # loop further down in this script.
    collection = get_collection(
        collection_slug="clips/e5-nl-68be9d3760240ce5c7d9f831",
    )

    # Regex that locates the legacy usage snippet inside a model card.
    # Group 1 is the model name passed to SentenceTransformer(...); groups 2-5
    # are the four entries of `input_texts`, in order of appearance.
    # The pattern is assembled from one raw string per snippet line so that
    # every newline it requires is spelled out explicitly.
    ST_SNIPPET_PATTERN = (
        # The snippet has to begin on a fresh line, hence the leading newline.
        r"\nfrom sentence_transformers import SentenceTransformer\n"
        r"model = SentenceTransformer\(['\"]([a-zA-Z0-9_\/\.-]+?)['\"]\)\n"
        r"input_texts = \[\n"
        # List entries may be indented in the card, so leading spaces/tabs are
        # tolerated; unindented entries still match exactly as before.
        r"[ \t]*['\"](.+?)['\"],\n"
        r"[ \t]*['\"](.+?)['\"],\n"
        r"[ \t]*['\"](.+?)['\"],\n"
        r"[ \t]*['\"](.+?)['\"]\n"
        r"\]\n"
        # The dot is escaped so only a literal `model.encode(` call matches.
        r"embeddings = model\.encode\(input_texts, normalize_embeddings=True\)\n"
    )

    # Replacement snippet written into the model card. It is filled in with
    # str.format using the fields model_name, query1, query2, document1,
    # document2, shapes and similarities.
    ST_SNIPPET_TEMPLATE = """
    from sentence_transformers import SentenceTransformer
    # Load the model from Hugging Face
    model = SentenceTransformer("{model_name}")
    # Perform inference using encode_query/encode_document for retrieval,
    # or encode_query for general purpose embeddings. Prompt prefixes
    # are automatically added with these two methods.
    queries = [
    {query1},
    {query2},
    ]
    documents = [
    {document1},
    {document2},
    ]
    query_embeddings = model.encode_query(queries)
    document_embeddings = model.encode_document(documents)
    print(query_embeddings.shape, document_embeddings.shape)
    {shapes}
    similarities = model.similarity(query_embeddings, document_embeddings)
    {similarities}
    """

    # Model ids listed here are skipped by the loop below.
    # NOTE(review): presumably filled in by hand between runs so that models
    # which were already pushed are not processed again - confirm.
    FINISHED_MODELS = []

    # Text prefixes used both as the model's prompts and to recognise/strip
    # the prefixed example inputs found in the legacy model card snippet.
    QUERY_PREFIX = "query: "
    DOCUMENT_PREFIX = "passage: "


    for item in collection.items:
        # Skip collection items that are not models.
        if item.item_type != "model":
            continue

        model_id = item.item_id

        if model_id in FINISHED_MODELS:
            continue

        # Register the prompts on the model; the snippet written into the card
        # relies on encode_query/encode_document adding them automatically.
        model = SentenceTransformer(
            model_id,
            prompts={
                "query": QUERY_PREFIX,
                "document": DOCUMENT_PREFIX,
            },
        )
        # Append a Normalize module under the name "2".
        model.add_module("2", Normalize())

        # Update the card metadata, keeping any tags that are already present.
        model_card = ModelCard.load(model_id)
        model_card.data.library_name = "sentence-transformers"
        model_card.data.language = "nl"
        tags = model_card.data.tags or []
        if "transformers" not in tags:
            tags.append("transformers")
        model_card.data.tags = tags

        content = model_card.content

        match = re.search(ST_SNIPPET_PATTERN, content)
        if match:
            # Groups 2-3 hold the two example queries, groups 4-5 the passages.
            raw_queries = [match.group(2), match.group(3)]
            raw_documents = [match.group(4), match.group(5)]

            # Every example is expected to carry its prefix; drop into the
            # debugger for manual inspection when one does not.
            if not all(text.startswith(QUERY_PREFIX) for text in raw_queries):
                print("Unexpected query format in model card for", model_id)
                breakpoint()
            if not all(text.startswith(DOCUMENT_PREFIX) for text in raw_documents):
                print("Unexpected passage format in model card for", model_id)
                breakpoint()

            # Strip only a leading prefix (str.removeprefix, Python 3.9+), so
            # example text that happens to contain the prefix again further on
            # is kept intact instead of being truncated.
            queries = [
                text.removeprefix(QUERY_PREFIX).strip() for text in raw_queries
            ]
            documents = [
                text.removeprefix(DOCUMENT_PREFIX).strip() for text in raw_documents
            ]

            # Run the model once so the card can show the real output as
            # comments underneath the example code.
            query_embeddings = model.encode_query(queries, normalize_embeddings=True)
            doc_embeddings = model.encode_document(documents, normalize_embeddings=True)
            shapes = f"# {query_embeddings.shape} {doc_embeddings.shape}"

            similarities = model.similarity(query_embeddings, doc_embeddings)
            # Turn the (possibly multi-line) repr into a block of "# " comments.
            similarity_comment = "# " + str(similarities).replace("\n", "\n# ")

            snippet = ST_SNIPPET_TEMPLATE.format(
                model_name=model_id,
                query1=repr(queries[0]),
                query2=repr(queries[1]),
                document1=repr(documents[0]),
                document2=repr(documents[1]),
                shapes=shapes,
                similarities=similarity_comment,
            )
            # Splice the new snippet in place of the matched legacy snippet.
            content = content[:match.start()] + snippet + content[match.end():]
        else:
            # The card text is left unchanged in this case; the debugger pause
            # gives the operator a chance to inspect or abort before the push.
            print("No match found in model card for", model_id)
            breakpoint()

        model_card.content = content
        model_card.validate("model")

        # NOTE(review): private attribute; presumably the text that
        # push_to_hub uploads as the model card - confirm against the
        # installed sentence-transformers version.
        model._model_card_text = str(model_card)
        # Alternative target (a private repository under another namespace):
        # model.push_to_hub(model_id.replace("clips/", "tomaarsen/"), private=True)
        url = model.push_to_hub(model_id, create_pr=True)
        print("Pushed", model_id, "->", url)
        # Pause after every push so each result can be checked by hand.
        breakpoint()