Add File
This commit is contained in:
114
src/landppt/ai/base.py
Normal file
114
src/landppt/ai/base.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""
|
||||||
|
Base classes for AI providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
|
||||||
|
"""Message roles for AI conversations"""
|
||||||
|
SYSTEM = "system"
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
|
||||||
|
class MessageContentType(str, Enum):
|
||||||
|
"""Message content types for multimodal support"""
|
||||||
|
TEXT = "text"
|
||||||
|
IMAGE_URL = "image_url"
|
||||||
|
|
||||||
|
class ImageContent(BaseModel):
|
||||||
|
"""Image content for multimodal messages"""
|
||||||
|
type: MessageContentType = MessageContentType.IMAGE_URL
|
||||||
|
image_url: Dict[str, str] # {"url": "data:image/jpeg;base64,..." or "http://..."}
|
||||||
|
|
||||||
|
class TextContent(BaseModel):
|
||||||
|
"""Text content for multimodal messages"""
|
||||||
|
type: MessageContentType = MessageContentType.TEXT
|
||||||
|
text: str
|
||||||
|
|
||||||
|
class AIMessage(BaseModel):
|
||||||
|
"""AI message model with multimodal support"""
|
||||||
|
role: MessageRole
|
||||||
|
content: Union[str, List[Union[TextContent, ImageContent]]] # Support both simple string and multimodal content
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
class AIResponse(BaseModel):
|
||||||
|
"""AI response model"""
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
usage: Dict[str, int]
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
class AIProvider(ABC):
|
||||||
|
"""Abstract base class for AI providers"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
self.config = config
|
||||||
|
self.model = config.get("model", "unknown")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[AIMessage],
|
||||||
|
**kwargs
|
||||||
|
) -> AIResponse:
|
||||||
|
"""Generate chat completion"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def text_completion(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs
|
||||||
|
) -> AIResponse:
|
||||||
|
"""Generate text completion"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def stream_chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[AIMessage],
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Stream chat completion (optional)"""
|
||||||
|
# Default implementation: return full response at once
|
||||||
|
response = await self.chat_completion(messages, **kwargs)
|
||||||
|
yield response.content
|
||||||
|
|
||||||
|
async def stream_text_completion(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Stream text completion (optional)"""
|
||||||
|
# Default implementation: return full response at once
|
||||||
|
response = await self.text_completion(prompt, **kwargs)
|
||||||
|
yield response.content
|
||||||
|
|
||||||
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get model information"""
|
||||||
|
return {
|
||||||
|
"model": self.model,
|
||||||
|
"provider": self.__class__.__name__,
|
||||||
|
"config": {k: v for k, v in self.config.items() if "key" not in k.lower()}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calculate_usage(self, prompt: str, response: str) -> Dict[str, int]:
|
||||||
|
"""Calculate token usage (simplified)"""
|
||||||
|
# Simplified calculation
|
||||||
|
prompt_tokens = len(prompt.split())
|
||||||
|
completion_tokens = len(response.split())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
def _merge_config(self, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""Merge provider config with request parameters"""
|
||||||
|
merged = self.config.copy()
|
||||||
|
merged.update(kwargs)
|
||||||
|
return merged
|
||||||
Reference in New Issue
Block a user