Last active
April 14, 2024 22:47
-
-
Save ohdearquant/dcff2c01368944bec55763d2753f8365 to your computer and use it in GitHub Desktop.
intelligent model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from lionagi.core.generic import BaseComponent | |
| from pydantic import BaseModel, Field, field_validator | |
| from abc import ABC, abstractmethod | |
class iModel(BaseComponent, ABC):
    """Abstract base for models that carry an id, a name and an LLM config.

    Subclasses must implement the asynchronous :meth:`predict`.
    """

    # Identifier shown in log lines; defaults to "" (the value of ``str()``).
    model_id: str = Field(default_factory=str)
    # Human-readable model name; defaults to "".
    model_name: str = Field(default_factory=str)
    # Free-form configuration mapping, merged in place by ``update_config``.
    llmconfig: dict = Field(default_factory=dict)

    def update_config(self, updates: dict) -> bool:
        """Merge *updates* into ``llmconfig`` in place and log the change.

        Args:
            updates: Key/value pairs to merge into ``llmconfig``.

        Returns:
            Always ``True``.
        """
        # Annotated with builtin ``dict``: ``typing.Dict`` is not imported
        # before this class is defined, so the old annotation raised NameError.
        self.llmconfig.update(updates)
        self.log("Configuration updated.")
        return True

    def log(self, message: str) -> None:
        """Basic logging function: print *message* tagged with ``model_id``."""
        print(f"Log [{self.model_id}]: {message}")

    @abstractmethod
    async def predict(self, input_data: dict):
        """
        Generate a prediction based on input_data.
        This method must be overridden by subclasses.
        """
        pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Dict | |
| from lionagi import Session, direct, Services | |
| from lionagi.libs.ln_api import BaseService | |
class Model(iModel):
    """Concrete ``iModel`` that delegates inference to ``lionagi.direct``.

    Each public inference method forwards its arguments to the same-named
    function in ``direct``, adding ``service=self.service``. Failures are
    logged and reported as sentinel strings instead of being raised; this is
    kept as-is for backward compatibility with existing callers.
    """

    # Backend passed to ``direct.*``; ``None`` means "not initialized".
    service: BaseService | None = None

    @field_validator("service", mode="before")
    @classmethod
    def _validate_service(cls, value=None):
        """Validate ``service``; an explicit ``None`` builds a default service.

        Raises:
            ValueError: if *value* is neither ``None`` nor a ``BaseService``.
        """
        # NOTE(review): pydantic v2 does not run field validators on the
        # default value unless ``validate_default=True``, so omitting
        # ``service`` leaves it ``None``. Confirm whether that is intended.
        if value is None:
            # The original called ``cls.initialize_service()``. That is an
            # *instance* method, so calling it from this class-level validator
            # raised TypeError (missing ``self``). Build the default here.
            try:
                return Services.OpenAI()
            except Exception as e:
                print(f"Initialization error: {e}")
                return None
        if not isinstance(value, BaseService):
            raise ValueError("Service is of invalid type")
        return value

    def check_service(self) -> bool:
        """Return ``True`` if a service is set; otherwise log and return ``False``."""
        if not self.service:
            self.log("Service not initialized.")
            return False
        return True

    async def _call_direct(self, method_name: str, /, *args, **kwargs):
        """Await ``direct.<method_name>(*args, service=self.service, **kwargs)``.

        This is the shared body of every public inference method. It returns
        ``"Service not initialized."`` when no service is set. Any exception
        (including a failed attribute lookup or a duplicate ``service``
        keyword) is logged and replaced by ``"Error generating output."``.
        """
        if not self.check_service():
            return "Service not initialized."
        try:
            func = getattr(direct, method_name)
            return await func(*args, service=self.service, **kwargs)
        except Exception as e:
            self.log(f"Error generating output: {e}")
            return "Error generating output."

    async def predict(self, *args, **kwargs):
        """Delegate to ``direct.predict`` using this model's service."""
        return await self._call_direct("predict", *args, **kwargs)

    async def chain_of_thoughts(self, *args, **kwargs):
        """Delegate to ``direct.chain_of_thoughts`` using this model's service."""
        return await self._call_direct("chain_of_thoughts", *args, **kwargs)

    async def chain_of_react(self, *args, **kwargs):
        """Delegate to ``direct.chain_of_react`` using this model's service."""
        return await self._call_direct("chain_of_react", *args, **kwargs)

    async def react(self, *args, **kwargs):
        """Delegate to ``direct.react`` using this model's service."""
        return await self._call_direct("react", *args, **kwargs)

    async def plan(self, *args, **kwargs):
        """Delegate to ``direct.plan`` using this model's service."""
        return await self._call_direct("plan", *args, **kwargs)

    async def vote(self, *args, **kwargs):
        """Delegate to ``direct.vote`` using this model's service."""
        return await self._call_direct("vote", *args, **kwargs)

    def initialize_service(self, service=None):
        """Return *service*, or a new ``Services.OpenAI()`` when it is falsy.

        On failure the error is logged and ``None`` is returned. This method
        does not assign the result to ``self.service``.
        """
        try:
            return service or Services.OpenAI()
        except Exception as e:
            self.log(f"Initialization error: {e}")
            return None
class ExpertModel(Model):
    """A ``Model`` acting as an expert for the gating and validator classes.

    It is currently identical to ``Model`` and adds no fields or methods.
    """
class GatingModel(Model):
    """Model that keeps per-expert weights and picks which experts to query.

    Weights are keyed by ``ExpertModel.model_id``. They are updated by an
    exponential moving average controlled by ``learning_rate``.
    """

    # model_id -> weight; an expert without an entry counts as 1.0 when selecting.
    expert_weights: dict = Field(default_factory=dict)
    # Smoothing factor of the moving-average update in ``update_weights``.
    learning_rate: float = Field(default=0.1)

    def update_weights(self, expert_id: str, performance_score: float) -> None:
        """Blend *performance_score* into the stored weight for *expert_id*.

        The first score seen for an expert becomes its weight. Later scores
        move the weight toward the new value by a fraction ``learning_rate``.
        """
        if expert_id in self.expert_weights:
            old = self.expert_weights[expert_id]
            self.expert_weights[expert_id] = (
                (1 - self.learning_rate) * old
                + self.learning_rate * performance_score
            )
        else:
            self.expert_weights[expert_id] = performance_score

    def select_experts(
        self, experts: list[ExpertModel], context: dict, k: int
    ) -> list[ExpertModel]:
        """Pick up to *k* experts, favouring those with higher weights.

        Each expert is scored ``random.random() * weight``. The *k*
        highest-scoring experts are returned in their original input order.
        *context* is currently unused.

        Raises:
            ValueError: if *experts* is empty.
        """
        import random  # was used but never imported

        if not experts:
            raise ValueError("No experts available to select from.")
        if k <= 0:
            # The old ``[:k]`` slice returned a surprising subset for k < 0.
            return []
        # Score each expert individually. The original keyed scores by
        # model_id, so experts sharing an id (the default is "") collapsed into
        # one entry, and *all* of them were returned regardless of k.
        scores = [
            random.random() * self.expert_weights.get(expert.model_id, 1.0)
            for expert in experts
        ]
        top = sorted(range(len(experts)), key=scores.__getitem__, reverse=True)[:k]
        # Keep the caller's ordering, as the original implementation did.
        return [experts[i] for i in sorted(top)]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any | |
class ModelManager:
    """Registry of models keyed by ``model_id``, with simple dispatch helpers."""

    def __init__(self, models: "dict[str, Model] | None" = None):
        """Manage *models*; ``None`` (the default) starts with an empty registry.

        A provided dict is stored by reference, not copied, as before.
        """
        self.models = {} if models is None else models

    def add_model(self, model: "Model") -> None:
        """Register *model* under its ``model_id``, replacing any previous entry."""
        self.models[model.model_id] = model
        model.log("Model added to manager.")

    def remove_model(self, model_id: str) -> None:
        """Remove the model registered as *model_id*; log if it is absent."""
        if model_id not in self.models:
            self.log(f"Model {model_id} not found.")
            return
        del self.models[model_id]
        self.log(f"Model {model_id} removed from manager.")

    def update_model(self, model_id: str, updates: dict) -> None:
        """Merge *updates* into the config of the model registered as *model_id*."""
        if model_id not in self.models:
            self.log(f"Model {model_id} not found.")
            return
        model = self.models[model_id]
        # Models expose ``update_config`` (see iModel). The original called a
        # non-existent ``update`` method and raised AttributeError.
        model.update_config(updates)
        model.log("Model updated.")

    def log(self, message: str) -> None:
        """Print *message* with a manager prefix."""
        print(f"Manager Log: {message}")

    async def predict(self, model_id: str, input_data: dict) -> Any:
        """Run a prediction on the model registered as *model_id*.

        Returns:
            The model's prediction, or ``"Model not found."`` if *model_id*
            is not registered.
        """
        if model_id not in self.models:
            self.log(f"Model {model_id} not found.")
            return "Model not found."
        model = self.models[model_id]
        # ``cached_predict`` is not defined on Model/iModel in this file. Use it
        # when a model provides one; otherwise fall back to ``predict``.
        runner = getattr(model, "cached_predict", None) or model.predict
        return await runner(input_data)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Validator(Worker):
    """A worker holding a single expert model.

    ``expert`` defaults to a freshly constructed ``ExpertModel``.
    """

    # NOTE(review): ``Worker`` is not defined or imported anywhere in this
    # file; presumably it comes from lionagi. Confirm the import.
    expert: ExpertModel = Field(default_factory=ExpertModel)
class UnifiedValidators(Validator):
    """Abstract validator that coordinates a collection of experts.

    Subclasses decide how experts are selected and how validation runs.
    """

    # Expert registry. NOTE(review): the key meaning is not established here
    # (presumably model ids); values are expected to be ExpertModel instances.
    experts: dict = Field(default_factory=dict)

    # Annotations use builtin ``list``: ``typing.List`` is never imported, so
    # the old ``List[...]`` annotations raised NameError at class definition.
    @abstractmethod
    def select_experts(self, k: int) -> list[ExpertModel]:
        """Return up to *k* experts to consult."""

    @abstractmethod
    def validate(self, *args, **kwargs):
        """Run validation; arguments and result are defined by subclasses."""
class RecursiveUnifiedValidators(UnifiedValidators):
    """Validator that queries a subset of experts, chosen by a gating model.

    With probability ``exploration_rate`` the experts are sampled uniformly at
    random (exploration). Otherwise ``gating_model`` picks them.
    """

    gating_model: GatingModel = Field(default_factory=GatingModel)
    # Probability of sampling experts at random instead of using the gate.
    exploration_rate: float = Field(default=0.1)
    # Number of ``forward`` passes performed by ``validate``.
    min_iterations: int = Field(default=1)

    def _expert_list(self) -> list[ExpertModel]:
        """Return the expert models stored as the values of ``self.experts``."""
        return list(self.experts.values())

    def select_experts(self, k: int) -> list[ExpertModel]:
        """Ask the gating model for up to *k* experts."""
        # ``self.experts`` is a dict. Pass its values (the models), not its
        # keys; the gate reads ``.model_id`` on each element.
        return self.gating_model.select_experts(self._expert_list(), context={}, k=k)

    async def forward(self, *args, k=3, iterations=1, **kwargs):
        """Query the selected experts concurrently, *iterations* times.

        Args:
            *args, **kwargs: Forwarded to each expert's ``predict``.
            k: Maximum number of experts to consult.
            iterations: Number of concurrent query rounds.

        Returns:
            One string per iteration: that round's expert outputs, converted
            with ``str`` and joined by single spaces.

        Raises:
            ValueError: if there are no experts.
        """
        import asyncio  # was used but never imported
        import random  # was used but never imported

        experts = self._expert_list()
        if not experts:
            raise ValueError("No experts available to select from.")
        k = min(k, len(experts))
        if random.random() < self.exploration_rate:
            # ``random.sample`` requires a sequence; passing the dict raised TypeError.
            selected = random.sample(experts, k=k)
        else:
            selected = self.select_experts(k=k)

        async def run_round():
            return await asyncio.gather(
                *(expert.predict(*args, **kwargs) for expert in selected)
            )

        # ``gather`` accepts awaitables only. The original passed lists of
        # coroutines, which raised TypeError (and leaked un-awaited coroutines).
        results = await asyncio.gather(*(run_round() for _ in range(iterations)))
        return [" ".join(map(str, outputs)) for outputs in results]

    async def validate(self, *args, **kwargs):
        """Run ``forward`` ``min_iterations`` times, printing each output.

        Returns:
            A list holding one ``forward`` result per pass.
        """
        outs = []
        for _ in range(self.min_iterations):
            outputs = await self.forward(*args, **kwargs)
            outs.append(outputs)
            # NOTE(review): each entry is one round's joined outputs, not a
            # single expert's answer, so the "Expert" label may mislead.
            for i, output in enumerate(outputs, start=1):
                print(f"Expert {i}: {output}")
        return outs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment