Skip to content

Instantly share code, notes, and snippets.

@softwaredoug
Created January 1, 2023 20:51
Show Gist options
  • Save softwaredoug/804eb9cb960f722f0c46d355a21936ba to your computer and use it in GitHub Desktop.

Revisions

  1. softwaredoug created this gist Jan 1, 2023.
    84 changes: 84 additions & 0 deletions encode-wikipedia-sentences.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,84 @@
    import numpy as np
    import os
    from time import perf_counter

    from sentence_transformers import SentenceTransformer, LoggingHandler
    import logging


    logging.basicConfig(format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG,
    handlers=[LoggingHandler()])


    def encode(sentences, chunk_size=20000):
    print("Loaded sentences")

    model_mini = SentenceTransformer('all-MiniLM-L6-v2')
    model_mpnet = SentenceTransformer('all-mpnet-base-v2')
    pool = model_mpnet.start_multi_process_pool()
    start = perf_counter()
    # pool = model.start_multi_process_pool()

    for chunk in range(0, len(sentences), chunk_size):
    mini_fname = f"data/wikisent2_{chunk}.npz"
    mpnet_fname = f"data/wikisent2-mpnet_{chunk}.npz"
    begin = chunk
    end = chunk + chunk_size
    if not os.path.exists(mini_fname):
    print(f"Processing mini {chunk}")
    embeddings = model_mini.encode(sentences[begin:end],
    show_progress_bar=True)
    print(f"Encoded sentences chunk {chunk} ({begin}-{end}) - {perf_counter() - start}")
    np.savez(mini_fname, embeddings)
    print("Saved sentences")
    else:
    print(f"Skipping mini {chunk}")

    if not os.path.exists(mpnet_fname):
    print(f"Processing mpnet {chunk}")
    embeddings = model_mpnet.encode_multi_process(sentences[begin:end], pool)
    print(f"Encoded sentences chunk {chunk} ({begin}-{end}) - {perf_counter() - start}")
    np.savez(mpnet_fname, embeddings)
    print("Saved sentences")
    else:
    print(f"Skipping mpnet {chunk}")


    def append(encoding="mini"):
    # Iterate all files in data/
    # Load them and append to a single file
    # This is to make it easier to load the data
    # in the future
    if encoding == "mini":
    encoding = ""
    files = []
    # Get all wikisent2_*.npz files in a list
    for fname in os.listdir("data"):
    if encoding != "":
    if fname.startswith(f"wikisent2-{encoding}") and fname.endswith(".npz"):
    files.append(fname)

    # Sort by chunk number
    files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

    # Load and append into one numpy array
    arrs = []
    for fname in files:
    print(f"Loading {fname}")
    arrs.append(np.load(f"data/{fname}").get("arr_0"))

    print("Concatenating")
    arr = np.concatenate(arrs)
    print(arr.shape)
    np.savez("data/wikisent2_{encoding}_all.npz", arr)


    if __name__ == "__main__":
    sentences = []
    with open('wikisent2.txt') as f:
    sentences = [line for line in f]
    encode(sentences)
    # append("mini")
    append("mpnet")