Add File
This commit is contained in:
114
src/landppt/ai/base.py
Normal file
114
src/landppt/ai/base.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Base classes for AI providers
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Dict, Any, Optional, AsyncGenerator, Union

from pydantic import BaseModel, Field
|
||||
|
||||
class MessageRole(str, Enum):
    """Role of a participant in an AI chat conversation.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """
    SYSTEM = "system"        # system-level instructions
    USER = "user"            # end-user input
    ASSISTANT = "assistant"  # model-generated replies
|
||||
|
||||
class MessageContentType(str, Enum):
    """Kind of a single content part in a multimodal message.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """
    TEXT = "text"            # plain-text part
    IMAGE_URL = "image_url"  # image referenced by URL or data URI
|
||||
|
||||
class ImageContent(BaseModel):
    """One image part of a multimodal message.

    Shape on the wire: ``{"type": "image_url", "image_url": {"url": ...}}``.
    """
    # Discriminator for this part kind; fixed to IMAGE_URL.
    type: MessageContentType = MessageContentType.IMAGE_URL
    image_url: Dict[str, str]  # {"url": "data:image/jpeg;base64,..." or "http://..."}
|
||||
|
||||
class TextContent(BaseModel):
    """One text part of a multimodal message.

    Shape on the wire: ``{"type": "text", "text": ...}``.
    """
    # Discriminator for this part kind; fixed to TEXT.
    type: MessageContentType = MessageContentType.TEXT
    text: str  # the text payload itself
|
||||
|
||||
class AIMessage(BaseModel):
    """A single message in an AI conversation, with multimodal support.

    ``content`` is either a plain string (text-only message) or a list of
    ``TextContent`` / ``ImageContent`` parts for mixed text-and-image input.
    """
    # Who authored the message: system, user, or assistant.
    role: MessageRole
    content: Union[str, List[Union[TextContent, ImageContent]]]  # Support both simple string and multimodal content
    # Optional sender name; exact semantics not shown in this file — TODO confirm against providers.
    name: Optional[str] = None
|
||||
|
||||
class AIResponse(BaseModel):
    """Normalized completion result returned by every AI provider.

    Providers map their native payloads onto this shape so callers can
    switch providers without changing code.
    """
    content: str                         # generated text
    model: str                           # identifier of the model that produced it
    usage: Dict[str, int]                # token accounting (prompt/completion/total)
    finish_reason: Optional[str] = None  # provider-reported stop reason; None if unknown
    # default_factory instead of a `{}` literal: pydantic does copy mutable
    # defaults per instance, but the factory form is the documented idiom
    # and makes the per-instance semantics explicit.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class AIProvider(ABC):
    """Abstract base class all AI providers implement.

    Subclasses must implement :meth:`chat_completion` and
    :meth:`text_completion`; the streaming variants default to emitting
    the full non-streaming result as a single chunk.
    """

    def __init__(self, config: Dict[str, Any]):
        # Keep the raw config; concrete providers read their own keys from it.
        self.config = config
        self.model = config.get("model", "unknown")

    @abstractmethod
    async def chat_completion(
        self,
        messages: List[AIMessage],
        **kwargs
    ) -> AIResponse:
        """Generate a chat completion for the given message history."""
        pass

    @abstractmethod
    async def text_completion(
        self,
        prompt: str,
        **kwargs
    ) -> AIResponse:
        """Generate a completion for a plain text prompt."""
        pass

    async def stream_chat_completion(
        self,
        messages: List[AIMessage],
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion (optional override).

        Fallback: run the non-streaming call and yield the whole reply once.
        """
        full = await self.chat_completion(messages, **kwargs)
        yield full.content

    async def stream_text_completion(
        self,
        prompt: str,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Stream a text completion (optional override).

        Fallback: run the non-streaming call and yield the whole reply once.
        """
        full = await self.text_completion(prompt, **kwargs)
        yield full.content

    def get_model_info(self) -> Dict[str, Any]:
        """Describe this provider, omitting config entries that look like secrets."""
        # Drop any config entry whose name contains "key" (api_key, secret_key, ...).
        safe_config = {
            name: value
            for name, value in self.config.items()
            if "key" not in name.lower()
        }
        return {
            "model": self.model,
            "provider": type(self).__name__,
            "config": safe_config,
        }

    def _calculate_usage(self, prompt: str, response: str) -> Dict[str, int]:
        """Rough token accounting based on whitespace word counts."""
        in_tokens = len(prompt.split())
        out_tokens = len(response.split())
        return {
            "prompt_tokens": in_tokens,
            "completion_tokens": out_tokens,
            "total_tokens": in_tokens + out_tokens,
        }

    def _merge_config(self, **kwargs) -> Dict[str, Any]:
        """Overlay per-request kwargs on top of the provider config."""
        # Fresh dict; request parameters win over provider defaults.
        return {**self.config, **kwargs}
|
||||
Reference in New Issue
Block a user