@@ -0,0 +1,217 @@
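
"""Summarize text with a local Ollama model through its OpenAI-compatible API.

Short inputs are summarized in a single streaming request; longer inputs are split
into overlapping chunks, each chunk is summarized separately, and the chunk
summaries are then combined into one final summary.

Example invocation (the script filename is a placeholder; pass any chat model you
have pulled into Ollama via --model):

    python summarize_ollama.py --input-file article.txt --model gemma2:9b --max-tokens 150
"""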
import argparse
import os
import time
import sys

from openai import OpenAI


def chunk_text(text, chunk_size=400, overlap=100):
    """Split text into overlapping chunks of roughly chunk_size characters."""
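    # Illustrative example (not from a real run): with chunk_size=400 and overlap=100,
    # a 1200-character text of short sentences yields roughly four chunks, each ending
    # just after sentence-ending punctuation and each beginning about 100 characters
    # before the previous chunk ended.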
    if len(text) <= chunk_size:
        return [text]

    chunks = []
    start = 0

    while start < len(text):
        # Determine end of current chunk
        end = min(start + chunk_size, len(text))

        # If we're not at the end of the text, try to find a good break point
        if end < len(text):
            # Look for a period, question mark, exclamation mark, or newline
            # followed by whitespace, scanning back up to 200 characters
            for i in range(end, max(start, end - 200), -1):
                if text[i-1] in '.!?\n' and text[i].isspace():
                    end = i
                    break

        # Add the chunk
        chunks.append(text[start:end])

        # Stop once the final chunk has been added; otherwise the overlap
        # would keep re-appending the tail of the text forever
        if end == len(text):
            break

        # Move start position for next chunk, accounting for overlap,
        # and guarantee forward progress even if overlap >= chunk_size
        start = max(end - overlap, start + 1)

    return chunks


def summarize_chunk(client, chunk, system_prompt, model, max_tokens):
    """Summarize a single chunk of text."""
    try:
        response = client.chat.completions.create(
            model=model,
            max_tokens=max_tokens,
            stream=False,  # No streaming for individual chunks
            messages=[
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": chunk
                }
            ]
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error summarizing chunk: {e}")
        return f"Error processing this part of the text: {e}"


def main():
    # gemma2:9b is a fast, relatively small model from Google that summarizes decently.
    # Note: an embedding model such as nomic-embed-text cannot be used here, because
    # this script calls the chat completions API.
    DEFAULT_OPEN_WEIGHTS_MODEL = "gemma2:9b"
    IGNORED_OLLAMA_API_KEY = "ollama"

    # This is the ollama server installed from ollama.com
    DEFAULT_OLLAMA_SERVER_URL = "http://localhost:11434/v1"

    # Default system prompt
    DEFAULT_SYSTEM_PROMPT = """You are a text summarization assistant.
Generate a concise summary of the given input text while preserving the key information and main points.
Provide the summary in three bullet points, totalling 100 words or less."""

    # System prompt for combining chunk summaries
    COMBINE_CHUNKS_PROMPT = """You are a text summarization assistant.
Combine the following summaries into a coherent overall summary.
Eliminate redundancies and ensure the final summary captures all key points.
Provide the summary in three to five bullet points, totalling 150 words or less."""

    parser = argparse.ArgumentParser(description='Summarize text using Ollama models')
    parser.add_argument('--input-file', type=str, default='', help='Path to the input text file')
    parser.add_argument('--input-text', type=str, default='', help='Input text to summarize')
    parser.add_argument('--output-file', type=str, default='', help='Path to save the output summary')
    parser.add_argument('--model', type=str, default=DEFAULT_OPEN_WEIGHTS_MODEL, help='Model to use for the API')
    parser.add_argument('--base-url', type=str, default=DEFAULT_OLLAMA_SERVER_URL, help='Base URL for the Ollama server (which is OpenAI-compatible)')
    parser.add_argument('--max-tokens', type=int, default=100, help='Maximum number of tokens in the summary')
    parser.add_argument('--system-prompt', type=str, default=DEFAULT_SYSTEM_PROMPT, help='Custom system prompt to use')
    parser.add_argument('--chunk-size', type=int, default=1000, help='Character count per chunk for long texts')
    parser.add_argument('--chunk-overlap', type=int, default=100, help='Character overlap between chunks')
    parser.add_argument('--no-chunking', action='store_true', help='Disable chunking regardless of text length')
    args = parser.parse_args()

    user_message = ""
    if args.input_file:
        # Read input from file
        try:
            with open(args.input_file, 'r') as file:
                user_message = file.read()
        except Exception as e:
            print(f"Error reading input file: {e}")
            sys.exit(1)
    elif args.input_text:
        # Use input text from command-line argument
        user_message = args.input_text
    else:
        print("Either --input-file or --input-text must be provided")
        sys.exit(1)

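    # Ollama exposes an OpenAI-compatible endpoint, so the standard OpenAI client can
    # talk to it; the api_key is required by the client but ignored by the local server.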
    client = OpenAI(
        api_key=IGNORED_OLLAMA_API_KEY,
        base_url=args.base_url
    )

    start = time.time()

    # Determine if we need to chunk the text (only when it spans more than three
    # chunks and --no-chunking is not set)
    should_chunk = len(user_message) > args.chunk_size * 3 and not args.no_chunking
    print(f"Prompt is {len(user_message)} characters. Chunking: {should_chunk}")

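    # Chunked path: a map-reduce style summary. Each chunk is summarized with a
    # non-streaming request, then the per-chunk summaries are merged in a second,
    # streamed request that produces the final summary.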
    if should_chunk:
        print(f"Text is {len(user_message)} characters long. Processing in chunks...")
        chunks = chunk_text(user_message, args.chunk_size, args.chunk_overlap)
        print(f"Split into {len(chunks)} chunks")

        # Process each chunk
        chunk_summaries = []
        for i, chunk in enumerate(chunks):
            print(f"\nProcessing chunk {i+1}/{len(chunks)} ({len(chunk)} characters)...")
            chunk_summary = summarize_chunk(client, chunk, args.system_prompt, args.model, args.max_tokens)
            chunk_summaries.append(chunk_summary)
            print(f"Chunk {i+1} summary: {chunk_summary[:100]}...")

        # Combine the summaries
        combined_text = "\n\n".join([f"Summary {i+1}:\n{summary}" for i, summary in enumerate(chunk_summaries)])

        print("\nCombining all summaries into final result...")

        # Stream the final combined summary
        try:
            stream = client.chat.completions.create(
                model=args.model,
                max_tokens=args.max_tokens * 2,  # Allow more tokens for the combined summary
                stream=True,
                messages=[
                    {
                        "role": "system",
                        "content": COMBINE_CHUNKS_PROMPT
                    },
                    {
                        "role": "user",
                        "content": combined_text
                    }
                ]
            )
        except Exception as e:
            print(f"Error creating final summary: {e}")
            sys.exit(1)

    else:
        # Process normally without chunking
        try:
            stream = client.chat.completions.create(
                model=args.model,
                max_tokens=args.max_tokens,
                stream=True,
                messages=[
                    {
                        "role": "system",
                        "content": args.system_prompt
                    },
                    {
                        "role": "user",
                        "content": user_message
                    }
                ]
            )
        except Exception as e:
            print(f"ChatCompletionStream error: {e}")
            sys.exit(1)

print("\nFinal Summary: ") |
|
|
|
|
|
content = "" |
|
|
completion_tokens = 0 |
|
|
|
|
|
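    # Stream events can arrive without text (for example the initial role-only delta),
    # so only non-None content is accumulated and echoed as it arrives.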
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            content += chunk.choices[0].delta.content
            print(chunk.choices[0].delta.content, end='', flush=True)
            # Count streamed deltas as a rough proxy for completion tokens
            completion_tokens += 1

print(f"\n\nFinal Output: \n{content}") |
|
|
|
|
|
elapsed = time.time() - start |
|
|
print(f"\n\nTokens generated in final Output: {completion_tokens}") |
|
|
print(f"Output tokens per Second: {completion_tokens/elapsed:.2f}") |
|
|
print(f"Total Execution Time: {elapsed:.2f} seconds") |
|
|
|
|
|
    # Save to output file if specified
    if args.output_file:
        try:
            with open(args.output_file, 'w') as file:
                file.write(content)
            print(f"Output saved to {args.output_file}")
        except Exception as e:
            print(f"Error saving to output file: {e}")

    # TIP: the FIRST time you run this code, the model is loaded into memory, and this
    # will be slow. On my Mac M2 I got 4 tokens/s.
    # But run it a second time within 5 minutes and it'll run 10x faster!


if __name__ == "__main__":
    main()