This commit is contained in:
2025-09-08 16:36:18 +08:00
parent 07c288eb93
commit 9f8e44f33e

113
backend/apps/mcp/mcp.py Normal file
View File

@@ -0,0 +1,113 @@
# Author: Junjun
# Date: 2025/7/1
from datetime import timedelta
import jwt
from fastapi import HTTPException, status, APIRouter
from fastapi.responses import StreamingResponse
# from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pydantic import ValidationError
from sqlmodel import select
from apps.chat.api.chat import create_chat
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion
from apps.chat.task.llm import LLMService
from apps.system.crud.user import authenticate
from apps.system.crud.user import get_db_user
from apps.system.models.system_model import UserWsModel
from apps.system.models.user import UserModel
from apps.system.schemas.system_schema import BaseUserDTO
from apps.system.schemas.system_schema import UserInfoDTO
from common.core import security
from common.core.config import settings
from common.core.deps import SessionDep
from common.core.schemas import TokenPayload, XOAuth2PasswordBearer, Token
from common.core.security import create_access_token
reusable_oauth2 = XOAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
router = APIRouter(tags=["mcp"], prefix="/mcp")
# @router.post("/access_token", operation_id="access_token")
# def local_login(
# session: SessionDep,
# form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
# ) -> Token:
# user = authenticate(session=session, account=form_data.username, password=form_data.password)
# if not user:
# raise HTTPException(status_code=400, detail="Incorrect account or password")
# access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# user_dict = user.to_dict()
# return Token(access_token=create_access_token(
# user_dict, expires_delta=access_token_expires
# ))
# @router.get("/ds_list", operation_id="get_datasource_list")
# async def datasource_list(session: SessionDep):
# return get_datasource_list(session=session)
#
#
# @router.get("/model_list", operation_id="get_model_list")
# async def get_model_list(session: SessionDep):
# return session.query(AiModelDetail).all()
@router.post("/mcp_start", operation_id="mcp_start")
async def mcp_start(session: SessionDep, chat: ChatStart):
user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password)
if not user:
raise HTTPException(status_code=400, detail="Incorrect account or password")
if not user.oid or user.oid == 0:
raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
user_dict = user.to_dict()
t = Token(access_token=create_access_token(
user_dict, expires_delta=access_token_expires
))
c = create_chat(session, user, CreateChat(origin=1), False)
return {"access_token": t.access_token, "chat_id": c.id}
@router.post("/mcp_question", operation_id="mcp_question")
async def mcp_question(session: SessionDep, chat: McpQuestion):
try:
payload = jwt.decode(
chat.token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
)
token_data = TokenPayload(**payload)
except (InvalidTokenError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
# session_user = await get_user_info(session=session, user_id=token_data.id)
db_user: UserModel = get_db_user(session=session, user_id=token_data.id)
session_user = UserInfoDTO.model_validate(db_user.model_dump())
session_user.isAdmin = session_user.id == 1 and session_user.account == 'admin'
if session_user.isAdmin:
session_user = session_user
ws_model: UserWsModel = session.exec(
select(UserWsModel).where(UserWsModel.uid == session_user.id, UserWsModel.oid == session_user.oid)).first()
session_user.weight = ws_model.weight if ws_model else -1
session_user = UserInfoDTO.model_validate(session_user)
if not session_user:
raise HTTPException(status_code=404, detail="User not found")
if session_user.status != 1:
raise HTTPException(status_code=400, detail="Inactive user")
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
# ask
llm_service = await LLMService.create(session_user, mcp_chat)
llm_service.init_record()
return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream")