Skip to content

Instantly share code, notes, and snippets.

@voluntas
Created April 18, 2025 13:49
Show Gist options
  • Select an option

  • Save voluntas/52cce879d1bf2b5d86a5bf31481bb1e4 to your computer and use it in GitHub Desktop.

Select an option

Save voluntas/52cce879d1bf2b5d86a5bf31481bb1e4 to your computer and use it in GitHub Desktop.

Revisions

  1. voluntas revised this gist Apr 18, 2025. No changes.
  2. voluntas created this gist Apr 18, 2025.
    65 changes: 65 additions & 0 deletions main.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,65 @@
    import duckdb
    import torch
    from transformers import AutoModel, AutoTokenizer

    # Hugging Face model id shared by the tokenizer and the embedding model.
    _MODEL_ID = "pfnet/plamo-embedding-1b"

    # Both loaders are called with trust_remote_code=True for this model id.
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
    model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True)

    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = model.to(device)

    # Sample sentences to embed and index.
    # Source: https://sora.shiguredo.jp/
    docs = [
        "WebRTC による音声・映像・メッセージメッセージのリアルタイムな配信と、その録音・録画を実現します",
        "お客様ご自身のサーバーにインストールしてご利用いただくパッケージソフトウェアです",
        "株式会社時雨堂がフルスクラッチで開発しており、日本語によるサポートとドキュメントを提供します",
    ]


    def main():
        """Embed the sample documents, store them in DuckDB, and rank them for a query.

        Steps:
          1. Open an in-memory DuckDB database and install/load the ``vss``
             extension.
          2. Create the ``sora_doc`` table with a ``FLOAT[2048]`` vector column.
          3. Encode every entry of the module-level ``docs`` with
             ``model.encode_document`` and insert text + embedding.
          4. Encode a fixed query with ``model.encode_query`` and print every
             document ordered by ascending cosine distance to the query.

        The connection is managed by a ``with`` block so it is closed even if
        embedding or one of the SQL statements raises.
        """
        # Create an in-memory database (no path given to connect()).
        with duckdb.connect() as conn:
            conn.sql("INSTALL vss")
            conn.sql("LOAD vss")

            conn.sql("CREATE SEQUENCE IF NOT EXISTS id_sequence START 1;")
            conn.sql(
                "CREATE TABLE IF NOT EXISTS sora_doc (id INTEGER DEFAULT nextval('id_sequence'), content TEXT, vector FLOAT[2048]);"
            )

            # inference_mode disables autograd tracking while embedding.
            with torch.inference_mode():
                for doc, doc_embedding in zip(docs, model.encode_document(docs, tokenizer)):
                    conn.execute(
                        "INSERT INTO sora_doc (content, vector) VALUES (?, ?)",
                        [
                            doc,
                            # Move to host memory and convert to a plain list of
                            # floats so it can be bound as a SQL parameter.
                            doc_embedding.cpu().squeeze().numpy().tolist(),
                        ],
                    )

            query = "時雨堂について教えてください"
            print("query:", query)

            with torch.inference_mode():
                query_embedding = model.encode_query(query, tokenizer)

            # The bound parameter is cast to FLOAT[2048] to match the column
            # type; ordering by distance puts the closest document first.
            result = conn.sql(
                """
                SELECT content, array_cosine_distance(vector, ?::FLOAT[2048]) as distance
                FROM sora_doc
                ORDER BY distance
                """,
                params=[query_embedding.cpu().squeeze().numpy().tolist()],
            )

            # Consume the rows while the connection is still open.
            for row in result.fetchall():
                print("distance:", row[1], "|", row[0])


    # Script entry point: run the demo only when executed directly, not on import.
    if __name__ == "__main__":
        main()
    # query: 時雨堂について教えてください
    # distance: 0.22656434774398804 | 株式会社時雨堂がフルスクラッチで開発しており、日本語によるサポートとドキュメントを提供します
    # distance: 0.39890891313552856 | お客様ご自身のサーバーにインストールしてご利用いただくパッケージソフトウェアです
    # distance: 0.5286199450492859 | WebRTC による音声・映像・メッセージメッセージのリアルタイムな配信と、その録音・録画を実現します
    11 changes: 11 additions & 0 deletions pyproject.toml
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,11 @@
    [project]
    name = "oreore-rag"
    version = "0.1.0"
    readme = "README.md"
    requires-python = ">=3.11"
    dependencies = [
    "duckdb>=1.2.2",
    "sentencepiece>=0.2.0",
    "torch>=2.6.0",
    "transformers>=4.51.3",
    ]