diff --git a/backend/apps/system/api/aimodel.py b/backend/apps/system/api/aimodel.py new file mode 100644 index 0000000..a96354d --- /dev/null +++ b/backend/apps/system/api/aimodel.py @@ -0,0 +1,159 @@ +import json +from typing import List, Union + +from fastapi.responses import StreamingResponse +from apps.ai_model.model_factory import LLMConfig, LLMFactory +from apps.system.schemas.ai_model_schema import AiModelConfigItem, AiModelCreator, AiModelEditor, AiModelGridItem +from fastapi import APIRouter, Query +from sqlmodel import func, select, update + +from apps.system.models.system_model import AiModelDetail +from common.core.deps import SessionDep, Trans +from common.utils.crypto import sqlbot_decrypt +from common.utils.time import get_timestamp +from common.utils.utils import SQLBotLogUtil, prepare_model_arg + +router = APIRouter(tags=["system/aimodel"], prefix="/system/aimodel") + +@router.post("/status") +async def check_llm(info: AiModelCreator, trans: Trans): + async def generate(): + try: + additional_params = {item.key: prepare_model_arg(item.val) for item in info.config_list if item.key and item.val} + config = LLMConfig( + model_type="openai" if info.protocol == 1 else "vllm", + model_name=info.base_model, + api_key=info.api_key, + api_base_url=info.api_domain, + additional_params=additional_params, + ) + llm_instance = LLMFactory.create_llm(config) + async for chunk in llm_instance.llm.astream("1+1=?"): + SQLBotLogUtil.info(chunk) + if chunk and isinstance(chunk, str): + yield json.dumps({"content": chunk}) + "\n" + if chunk and isinstance(chunk, dict) and chunk.content: + yield json.dumps({"content": chunk.content}) + "\n" + + except Exception as e: + SQLBotLogUtil.error(f"Error checking LLM: {e}") + error_msg = trans('i18n_llm.validate_error', msg=str(e)) + yield json.dumps({"error": error_msg}) + "\n" + + return StreamingResponse(generate(), media_type="application/x-ndjson") + +@router.get("/default") +async def check_default(session: SessionDep, trans: Trans): + db_model = session.exec( + select(AiModelDetail).where(AiModelDetail.default_model == True) + ).first() + if not db_model: + raise Exception(trans('i18n_llm.miss_default')) + +@router.put("/default/{id}") +async def set_default(session: SessionDep, id: int): + db_model = session.get(AiModelDetail, id) + if not db_model: + raise ValueError(f"AiModelDetail with id {id} not found") + if db_model.default_model: + return + + try: + session.exec( + update(AiModelDetail).values(default_model=False) + ) + db_model.default_model = True + session.add(db_model) + session.commit() + except Exception as e: + session.rollback() + raise e + +@router.get("", response_model=list[AiModelGridItem]) +async def query( + session: SessionDep, + keyword: Union[str, None] = Query(default=None, max_length=255) +): + statement = select(AiModelDetail.id, + AiModelDetail.name, + AiModelDetail.model_type, + AiModelDetail.base_model, + AiModelDetail.supplier, + AiModelDetail.protocol, + AiModelDetail.default_model) + if keyword is not None: + statement = statement.where(AiModelDetail.name.like(f"%{keyword}%")) + statement = statement.order_by(AiModelDetail.default_model.desc(), AiModelDetail.name, AiModelDetail.create_time) + items = session.exec(statement).all() + return items + +@router.get("/{id}", response_model=AiModelEditor) +async def get_model_by_id( + session: SessionDep, + id: int +): + db_model = session.get(AiModelDetail, id) + if not db_model: + raise ValueError(f"AiModelDetail with id {id} not found") + + config_list: List[AiModelConfigItem] = [] + if db_model.config: + try: + raw = json.loads(db_model.config) + config_list = [AiModelConfigItem(**item) for item in raw] + except Exception: + pass + if db_model.api_key: + db_model.api_key = await sqlbot_decrypt(db_model.api_key) + if db_model.api_domain: + db_model.api_domain = await sqlbot_decrypt(db_model.api_domain) + data = AiModelDetail.model_validate(db_model).model_dump(exclude_unset=True) + data.pop("config", None) + data["config_list"] = config_list + return AiModelEditor(**data) + +@router.post("") +async def add_model( + session: SessionDep, + creator: AiModelCreator +): + data = creator.model_dump(exclude_unset=True) + data["config"] = json.dumps([item.model_dump(exclude_unset=True) for item in creator.config_list]) + data.pop("config_list", None) + detail = AiModelDetail.model_validate(data) + detail.create_time = get_timestamp() + count = session.exec(select(func.count(AiModelDetail.id))).one() + if count == 0: + detail.default_model = True + session.add(detail) + session.commit() + +@router.put("") +async def update_model( + session: SessionDep, + editor: AiModelEditor +): + id = int(editor.id) + data = editor.model_dump(exclude_unset=True) + data["config"] = json.dumps([item.model_dump(exclude_unset=True) for item in editor.config_list]) + data.pop("config_list", None) + db_model = session.get(AiModelDetail, id) + #update_data = AiModelDetail.model_validate(data) + db_model.sqlmodel_update(data) + session.add(db_model) + session.commit() + +@router.delete("/{id}") +async def delete_model( + session: SessionDep, + trans: Trans, + id: int +): + item = session.get(AiModelDetail, id) + if item.default_model: + raise Exception(trans('i18n_llm.delete_default_error', key = item.name)) + session.delete(item) + session.commit() + + + \ No newline at end of file