#!/usr/bin/env python3
"""Command-line helper for inspecting, tagging and managing Ollama models.

The script talks to an Ollama server over HTTP (``requests``), reads the
installed model list from the ``ollama list`` command, derives descriptive
tags for each model (size, recency, type) and lets the user attach custom
tags that are persisted in ``custom_tags.json`` in the working directory.
"""
import argparse
import json
import os
import re
import subprocess
from typing import Any, Dict, List, Optional, Union

import requests
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
from rich.tree import Tree

# File (relative to the current working directory) holding user-defined tags.
CUSTOM_TAGS_FILE = 'custom_tags.json'

# Column separator for `ollama list` rows: a tab (with optional padding) or a
# run of two or more whitespace characters. Single spaces are kept so values
# such as "4.7 GB" or "2 weeks ago" stay in one piece.
_COLUMN_SPLIT = re.compile(r'\s*\t\s*|\s{2,}')

# "<number><optional space><optional unit>", e.g. "4.7 GB", "274 MB", "7B".
_SIZE_PATTERN = re.compile(r'^\s*([0-9]*\.?[0-9]+)\s*([A-Za-z]*)\s*$')

# Multipliers that turn "<number> <unit>" into the GB-scale value the size
# buckets are expressed in. A bare "B" is treated as the same scale as GB
# because parameter sizes such as "7B" are bucketed with the same thresholds.
_UNIT_TO_GB = {
    '': 1.0,
    'b': 1.0,
    'g': 1.0, 'gb': 1.0, 'gib': 1.0,
    'm': 1.0 / 1024, 'mb': 1.0 / 1024, 'mib': 1.0 / 1024,
    'k': 1.0 / 1024 ** 2, 'kb': 1.0 / 1024 ** 2, 'kib': 1.0 / 1024 ** 2,
    't': 1024.0, 'tb': 1024.0, 'tib': 1024.0,
}


class OllamaModelManager:
    """Thin client around the Ollama HTTP API plus local model tagging.

    Attributes:
        base_url: Root URL of the Ollama server.
        console: ``rich`` console used for all user-facing output.
        custom_tags: Mapping of model name -> list of user-defined tags,
            loaded from and saved to ``custom_tags.json``.
    """

    def __init__(self, base_url: str = "http://localhost:11434"):
        self.base_url = base_url
        self.console = Console()
        self.custom_tags = self.load_custom_tags()

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------
    def _request(self, method: str, endpoint: str,
                 data: Optional[Dict] = None) -> Union[Dict, List[Dict], None]:
        """Send a JSON request and decode the reply.

        Args:
            method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended to ``base_url``.
            data: Optional JSON-serialisable request body.

        Returns:
            * a dict (or whatever single JSON value the server sent),
            * ``{}`` when the server answered 2xx with an empty body,
            * a list of dicts when the body is newline-delimited JSON
              (the shape produced by streaming endpoints),
            * ``None`` when the request failed; the error is printed.
        """
        url = f"{self.base_url}{endpoint}"
        try:
            response = requests.request(method, url, json=data)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.console.print(f"[bold red]Error making request to {url}: {str(e)}[/bold red]")
            return None

        body = response.text.strip()
        if not body:
            # A successful reply without a body is still a success; callers
            # distinguish it from failure with `is not None`.
            # NOTE(review): the Ollama API docs describe /api/copy and
            # /api/delete as replying this way - confirm for your server.
            return {}
        try:
            return json.loads(body)
        except ValueError:
            pass  # Not a single JSON document; try one document per line.
        try:
            return [json.loads(line) for line in body.splitlines() if line.strip()]
        except ValueError as e:
            self.console.print(f"[bold red]Error making request to {url}: {str(e)}[/bold red]")
            return None

    @staticmethod
    def _response_error(result: Any) -> Optional[str]:
        """Return the first ``error`` value found in a decoded reply, if any."""
        items = result if isinstance(result, list) else [result]
        for item in items:
            if isinstance(item, dict) and item.get('error'):
                return str(item['error'])
        return None

    def _report_outcome(self, result: Any, success_message: str) -> None:
        """Print ``success_message`` unless the request failed or reported an error.

        A ``None`` result means `_request` already printed the failure.
        """
        if result is None:
            return
        error = self._response_error(result)
        if error:
            self.console.print(f"[bold red]Error: {error}[/bold red]")
        else:
            self.console.print(success_message)

    # ------------------------------------------------------------------
    # API operations
    # ------------------------------------------------------------------
    def create_model(self, name: str, modelfile: str,
                     stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Create model ``name`` from the text of a Modelfile.

        Returns the decoded server reply, or ``None`` on failure.
        """
        data = {"name": name, "modelfile": modelfile, "stream": stream}
        result = self._request("POST", "/api/create", data)
        self._report_outcome(result, f"[green]Model {name} created successfully.[/green]")
        return result

    def check_blob_exists(self, digest: str) -> bool:
        """Return True when a HEAD request for the blob answers 200."""
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            response = requests.head(url)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def create_blob(self, digest: str, file_path: str) -> Optional[str]:
        """Upload ``file_path`` as blob ``digest``.

        Returns the response text, or ``None`` if reading the file or the
        upload failed (the error is printed).
        """
        url = f"{self.base_url}/api/blobs/{digest}"
        try:
            with open(file_path, 'rb') as file:
                response = requests.post(url, data=file)
            response.raise_for_status()
            return response.text
        except (IOError, requests.exceptions.RequestException) as e:
            self.console.print(f"[bold red]Error creating blob: {str(e)}[/bold red]")
            return None

    def list_local_models(self) -> Optional[Dict]:
        """Fetch ``/api/tags`` and add a ``tags`` list to every model entry."""
        models = self._request("GET", "/api/tags")
        if isinstance(models, dict):
            for model in models.get('models', []):
                model['tags'] = self.add_tags(model)
        return models

    def show_model_info(self, name: str, verbose: bool = False) -> Optional[Dict]:
        """Fetch ``/api/show`` for ``name`` and add a ``tags`` list to the reply."""
        data = {"name": name, "verbose": verbose}
        info = self._request("POST", "/api/show", data)
        if info and isinstance(info, dict):
            details = info.get('details') or {}
            size = details.get('parameter_size') or details.get('size', 'Unknown')
            modified = info.get('modified_at', 'Unknown')
            info['tags'] = self.add_tags({'name': name, 'size': size, 'modified': modified})
        return info

    def copy_model(self, source: str, destination: str) -> bool:
        """Copy a model; custom tags of ``source`` are copied to ``destination``.

        Returns True when the server accepted the request.
        """
        data = {"source": source, "destination": destination}
        result = self._request("POST", "/api/copy", data)
        # `is None` rather than falsiness: an empty 2xx reply decodes to {}.
        if result is None:
            return False
        self.console.print(f"[green]Model {source} copied to {destination} successfully.[/green]")
        if source in self.custom_tags:
            self.custom_tags[destination] = self.custom_tags[source].copy()
            self.save_custom_tags()
        return True

    def delete_model(self, name: str) -> bool:
        """Delete a model and drop its custom tags.

        Returns True when the server accepted the request.
        """
        data = {"name": name}
        result = self._request("DELETE", "/api/delete", data)
        # `is None` rather than falsiness: an empty 2xx reply decodes to {}.
        if result is None:
            return False
        self.console.print(f"[green]Model {name} deleted successfully.[/green]")
        if name in self.custom_tags:
            del self.custom_tags[name]
            self.save_custom_tags()
        return True

    def pull_model(self, name: str, insecure: bool = False,
                   stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Ask the server to pull ``name``; returns the decoded reply or None."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/pull", data)
        self._report_outcome(result, f"[green]Model {name} pulled successfully.[/green]")
        return result

    def push_model(self, name: str, insecure: bool = False,
                   stream: bool = True) -> Union[Dict, List[Dict], None]:
        """Ask the server to push ``name``; returns the decoded reply or None."""
        data = {"name": name, "insecure": insecure, "stream": stream}
        result = self._request("POST", "/api/push", data)
        self._report_outcome(result, f"[green]Model {name} pushed successfully.[/green]")
        return result

    def generate_embeddings(self, model: str, input: Union[str, List[str]],
                            truncate: bool = True, options: Optional[Dict] = None,
                            keep_alive: str = "5m") -> Optional[Dict]:
        """POST to ``/api/embed`` and return the decoded reply (None on failure)."""
        data = {
            "model": model,
            "input": input,
            "truncate": truncate,
            "options": options or {},
            "keep_alive": keep_alive,
        }
        return self._request("POST", "/api/embed", data)

    def list_running_models(self) -> Optional[Dict]:
        """Fetch ``/api/ps`` and add a ``tags`` list to every model entry."""
        models = self._request("GET", "/api/ps")
        if isinstance(models, dict):
            for model in models.get('models', []):
                model['tags'] = self.add_tags(model)
        return models

    # ------------------------------------------------------------------
    # `ollama list` parsing
    # ------------------------------------------------------------------
    def get_ollama_list(self) -> List[str]:
        """Run ``ollama list`` and return its non-empty data rows (header skipped).

        Returns an empty list when the command fails or is not installed.
        """
        try:
            result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            self.console.print(f"[bold red]Error running 'ollama list': {str(e)}[/bold red]")
            return []
        rows = result.stdout.strip().splitlines()[1:]  # Skip header
        return [row for row in rows if row.strip()]

    def parse_line(self, line: str) -> Dict[str, str]:
        """Split one ``ollama list`` row into name / id / size / modified.

        Columns may be separated by tabs or by runs of spaces; missing
        trailing columns come back as empty strings instead of raising.
        """
        parts = [part.strip() for part in _COLUMN_SPLIT.split(line.strip()) if part.strip()]
        parts += [''] * (4 - len(parts))
        return {
            'name': parts[0],
            'id': parts[1],
            'size': parts[2],
            'modified': ' '.join(parts[3:]).strip(),
        }

    # ------------------------------------------------------------------
    # Tagging
    # ------------------------------------------------------------------
    @staticmethod
    def _size_in_gb(size: Any) -> Optional[float]:
        """Normalise a size value to the GB scale used by the size buckets.

        Accepts a raw number (interpreted as a byte count, the form the HTTP
        listing endpoints appear to use - TODO confirm) or a string such as
        ``"4.7 GB"``, ``"274 MB"``, ``"7B"`` or ``"137M"``. Returns ``None``
        when the value cannot be interpreted.
        """
        if isinstance(size, bool) or size is None:
            return None
        if isinstance(size, (int, float)):
            return float(size) / 1024 ** 3
        match = _SIZE_PATTERN.match(str(size))
        if match is None:
            return None
        factor = _UNIT_TO_GB.get(match.group(2).lower())
        if factor is None:
            return None
        return float(match.group(1)) * factor

    @staticmethod
    def _recency_tag(modified: Any) -> str:
        """Map a relative "modified" phrase to recent / week-old / older."""
        text = str(modified).lower()
        # Anything measured in seconds, minutes or hours (singular or plural,
        # e.g. "About an hour ago") happened today.
        if any(unit in text for unit in ('second', 'minute', 'hour')):
            return 'recent'
        if 'day' in text:
            words = text.split()
            try:
                count = int(words[0])
            except (ValueError, IndexError):
                count = 1  # Phrases like "a day ago" carry no leading number.
            return 'week-old' if count <= 7 else 'older'
        return 'older'

    def add_tags(self, model: Dict[str, Any]) -> List[str]:
        """Derive descriptive tags for a model dict.

        Reads ``name``, ``size`` and ``modified`` (falling back to
        ``modified_at``). Produces, in order: a size tag (omitted when the
        size is unknown), a recency tag, type tags inferred from the name,
        then any custom tags stored for that name.
        """
        tags: List[str] = []

        # Size-based tags
        size = self._size_in_gb(model.get('size'))
        if size is not None:
            if size < 1:
                tags.append('tiny')
            elif size < 3:
                tags.append('small')
            elif size < 7:
                tags.append('medium')
            else:
                tags.append('large')

        # Recency tags
        modified = model.get('modified') or model.get('modified_at') or ''
        tags.append(self._recency_tag(modified))

        # Model type tags
        raw_name = model.get('name', '')
        name = str(raw_name).lower()
        if 'embed' in name:
            tags.append('embedding')
        if any(prefix in name for prefix in ['llama', 'mistral', 'phi']):
            tags.append('llm')
        if 'code' in name:
            tags.append('code')

        # Add custom tags
        tags.extend(self.custom_tags.get(raw_name, []))
        return tags

    def filter_models(self, models: List[Dict[str, Any]],
                      tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Return the models carrying at least one of ``tags`` (all if no tags)."""
        if not tags:
            return models
        return [model for model in models if any(tag in model['tags'] for tag in tags)]

    def list_models_with_tags(self, tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Parse ``ollama list``, tag every model and optionally filter by tags."""
        models = [self.parse_line(line) for line in self.get_ollama_list()]
        for model in models:
            model['tags'] = self.add_tags(model)
        return self.filter_models(models, tags)

    def get_embedding_model(self) -> Optional[str]:
        """Return the name of the first model tagged ``embedding``, if any."""
        embedding_models = self.list_models_with_tags(['embedding'])
        if embedding_models:
            return embedding_models[0]['name']
        return None

    def get_llm_model(self) -> Optional[str]:
        """Return the name of the first model tagged ``llm``, if any."""
        llm_models = self.list_models_with_tags(['llm'])
        if llm_models:
            return llm_models[0]['name']
        return None

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------
    def display_all_models(self, models: List[Dict[str, Any]]):
        """Print a table of every model with its derived and custom tags."""
        table = Table(title="All Models", box=box.DOUBLE_EDGE)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Modified", style="green")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")

        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                model['size'],
                model['modified'],
                ", ".join(model['tags']),
                custom_tags,
            )

        self.console.print(Panel(table, expand=False, border_style="bold white"))

    def display_filtered_models(self, models: List[Dict[str, Any]], filter_tags: List[str]):
        """Print a table of the models that matched ``filter_tags``."""
        table = Table(title=f"Models Tagged: {', '.join(filter_tags)}", box=box.SIMPLE_HEAVY)
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Size", style="magenta")
        table.add_column("Tags", style="yellow")
        table.add_column("Custom Tags", style="red")

        for model in models:
            custom_tags = ", ".join(self.custom_tags.get(model['name'], []))
            table.add_row(
                model['name'],
                model['size'],
                ", ".join(model['tags']),
                custom_tags,
            )

        self.console.print(Panel(table, expand=False, border_style="bold green"))

    def display_model_info(self, model_name: str, info: Dict):
        """Print a tree with the details, Modelfile, parameters and tags of a model.

        Sections whose key is absent from ``info`` are rendered empty rather
        than raising ``KeyError``.
        """
        tree = Tree(f"[bold cyan]{model_name}[/bold cyan]")

        details = tree.add("Details")
        for key, value in (info.get('details') or {}).items():
            details.add(f"[yellow]{key}[/yellow]: {value}")

        modelfile = tree.add("Modelfile")
        modelfile.add(Syntax(info.get('modelfile') or '', "dockerfile",
                             theme="monokai", line_numbers=True))

        parameters = tree.add("Parameters")
        for line in (info.get('parameters') or '').split('\n'):
            parameters.add(line)

        tags = tree.add("Tags")
        tags.add(", ".join(info.get('tags') or []))

        self.console.print(Panel(tree, title="Model Information", expand=False,
                                 border_style="bold magenta"))

    def display_embeddings(self, model_name: str, embeddings: List[float]):
        """Print the first ten values of an embedding vector."""
        table = Table(title=f"Embeddings from {model_name}", box=box.SIMPLE)
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Value", style="magenta")

        for i, value in enumerate(embeddings[:10]):  # Display first 10 values
            table.add_row(str(i), f"{value:.6f}")

        self.console.print(Panel(table, expand=False, border_style="bold yellow"))

    # ------------------------------------------------------------------
    # Custom tag persistence
    # ------------------------------------------------------------------
    def add_custom_tag(self, model_name: str, tag: str) -> bool:
        """Attach ``tag`` to ``model_name``; False if it was already present."""
        model_tags = self.custom_tags.setdefault(model_name, [])
        if tag in model_tags:
            return False
        model_tags.append(tag)
        self.save_custom_tags()
        return True

    def remove_custom_tag(self, model_name: str, tag: str) -> bool:
        """Detach ``tag`` from ``model_name``; False if it was not present."""
        model_tags = self.custom_tags.get(model_name)
        if not model_tags or tag not in model_tags:
            return False
        model_tags.remove(tag)
        if not model_tags:
            # Do not keep empty entries around in the saved file.
            del self.custom_tags[model_name]
        self.save_custom_tags()
        return True

    def load_custom_tags(self) -> Dict[str, List[str]]:
        """Load the custom tag mapping from ``custom_tags.json``.

        Returns an empty mapping when the file is missing, is not valid JSON
        or does not contain a JSON object.
        """
        try:
            with open(CUSTOM_TAGS_FILE, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            return {}
        except ValueError as e:
            self.console.print(
                f"[bold red]Ignoring unreadable {CUSTOM_TAGS_FILE}: {str(e)}[/bold red]")
            return {}
        return loaded if isinstance(loaded, dict) else {}

    def save_custom_tags(self):
        """Write the custom tag mapping to ``custom_tags.json``."""
        with open(CUSTOM_TAGS_FILE, 'w', encoding='utf-8') as f:
            json.dump(self.custom_tags, f)


def main():
    """Parse command-line options and run the requested operations."""
    parser = argparse.ArgumentParser(description="Ollama Model Manager")
    parser.add_argument("--list", action="store_true", help="List all models")
    parser.add_argument("--add-tag", nargs=2, metavar=('MODEL', 'TAG'),
                        help="Add a custom tag to a model")
    parser.add_argument("--remove-tag", nargs=2, metavar=('MODEL', 'TAG'),
                        help="Remove a custom tag from a model")
    parser.add_argument("--filter", nargs='+', help="Filter models by tags")
    parser.add_argument("--info", metavar='MODEL',
                        help="Show detailed information for a specific model")
    parser.add_argument("--embeddings", nargs=2, metavar=('MODEL', 'TEXT'),
                        help="Generate embeddings for the given text using the specified model")
    parser.add_argument("--create", nargs=2, metavar=('NAME', 'MODELFILE'),
                        help="Create a new model")
    parser.add_argument("--delete", metavar='MODEL', help="Delete a model")
    parser.add_argument("--copy", nargs=2, metavar=('SOURCE', 'DESTINATION'),
                        help="Copy a model")
    parser.add_argument("--pull", metavar='MODEL', help="Pull a model")
    parser.add_argument("--push", metavar='MODEL', help="Push a model")

    args = parser.parse_args()

    api = OllamaModelManager()

    # With no options at all, fall back to listing every model.
    if args.list or not any(vars(args).values()):
        all_models = api.list_models_with_tags()
        api.display_all_models(all_models)

    if args.add_tag:
        model, tag = args.add_tag
        if api.add_custom_tag(model, tag):
            api.console.print(f"[green]Added tag '{tag}' to {model}[/green]")
        else:
            api.console.print(f"[red]Failed to add tag '{tag}' to {model}[/red]")

    if args.remove_tag:
        model, tag = args.remove_tag
        if api.remove_custom_tag(model, tag):
            api.console.print(f"[green]Removed tag '{tag}' from {model}[/green]")
        else:
            api.console.print(f"[red]Failed to remove tag '{tag}' from {model}[/red]")

    if args.filter:
        filtered_models = api.list_models_with_tags(args.filter)
        api.display_filtered_models(filtered_models, args.filter)

    if args.info:
        model_info = api.show_model_info(args.info)
        if model_info:
            api.display_model_info(args.info, model_info)
        else:
            api.console.print(f"[red]Unable to fetch information for model: {args.info}[/red]")

    if args.embeddings:
        model, text = args.embeddings
        embeddings = api.generate_embeddings(model, text)
        vectors = embeddings.get('embeddings') if isinstance(embeddings, dict) else None
        if vectors:
            api.display_embeddings(model, vectors[0])
        else:
            api.console.print(f"[red]Unable to generate embeddings using model: {model}[/red]")

    if args.create:
        name, modelfile = args.create
        try:
            with open(modelfile, 'r', encoding='utf-8') as f:
                modelfile_content = f.read()
        except OSError as e:
            api.console.print(f"[bold red]Error reading Modelfile {modelfile}: {str(e)}[/bold red]")
        else:
            api.create_model(name, modelfile_content)

    if args.delete:
        api.delete_model(args.delete)

    if args.copy:
        source, destination = args.copy
        api.copy_model(source, destination)

    if args.pull:
        api.pull_model(args.pull)

    if args.push:
        api.push_model(args.push)


if __name__ == "__main__":
    main()