Files
SQLBot/backend/apps/chat/api/chat.py
2025-09-08 16:35:58 +08:00

226 lines
8.2 KiB
Python

import asyncio
import io
import traceback
import numpy as np
import orjson
import pandas as pd
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, select
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData
from apps.chat.task.llm import LLMService
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
# Router that the handlers in this module are registered on; every path below
# is served under the "/chat" prefix and tagged "Data Q&A".
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@router.get("/list")
async def chats(session: SessionDep, current_user: CurrentUser):
    """Handle ``GET /list`` on the chat router.

    Args:
        session: Injected database session.
        current_user: Injected current user.

    Returns:
        Whatever ``list_chats`` returns for this session and user.
    """
    chat_list = list_chats(session, current_user)
    return chat_list
@router.get("/get/{chart_id}")
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant):
    """Handle ``GET /get/{chart_id}``: load one chat together with its records.

    The lookup is handed to ``asyncio.to_thread`` so it runs in a worker
    thread rather than on the event loop.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        chart_id: Identifier taken from the URL path.
        current_assistant: Injected current assistant.

    Returns:
        The result of ``get_chat_with_records``.
    """
    # to_thread forwards keyword arguments to the callable, so no wrapper
    # closure is required.
    return await asyncio.to_thread(
        get_chat_with_records,
        chart_id=chart_id,
        session=session,
        current_user=current_user,
        current_assistant=current_assistant,
    )
@router.get("/get/with_data/{chart_id}")
async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int,
                             current_assistant: CurrentAssistant):
    """Handle ``GET /get/with_data/{chart_id}``.

    Calls ``get_chat_with_records_with_data`` in a worker thread via
    ``asyncio.to_thread`` and returns its result.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        chart_id: Identifier taken from the URL path.
        current_assistant: Injected current assistant.

    Returns:
        The result of ``get_chat_with_records_with_data``.
    """
    lookup_kwargs = {
        'chart_id': chart_id,
        'session': session,
        'current_user': current_user,
        'current_assistant': current_assistant,
    }
    return await asyncio.to_thread(get_chat_with_records_with_data, **lookup_kwargs)
@router.get("/record/get/{chart_record_id}/data")
async def chat_record_data(session: SessionDep, chart_record_id: int):
    """Handle ``GET /record/get/{chart_record_id}/data``.

    Args:
        session: Injected database session.
        chart_record_id: Record identifier taken from the URL path.

    Returns:
        The result of ``get_chat_chart_data``, computed in a worker thread.
    """
    # Keyword arguments are passed straight through by asyncio.to_thread.
    return await asyncio.to_thread(
        get_chat_chart_data,
        chart_record_id=chart_record_id,
        session=session,
    )
@router.get("/record/get/{chart_record_id}/predict_data")
async def chat_predict_data(session: SessionDep, chart_record_id: int):
    """Handle ``GET /record/get/{chart_record_id}/predict_data``.

    Args:
        session: Injected database session.
        chart_record_id: Record identifier taken from the URL path.

    Returns:
        The result of ``get_chat_predict_data``, computed in a worker thread.
    """
    # Keyword arguments are passed straight through by asyncio.to_thread.
    return await asyncio.to_thread(
        get_chat_predict_data,
        chart_record_id=chart_record_id,
        session=session,
    )
@router.post("/rename")
async def rename(session: SessionDep, chat: RenameChat):
    """Handle ``POST /rename`` by delegating to ``rename_chat``.

    Args:
        session: Injected database session.
        chat: Request body describing the rename.

    Returns:
        The result of ``rename_chat``.

    Raises:
        HTTPException: Status 500 with the error text as detail when
            ``rename_chat`` raises anything.
    """
    try:
        renamed = rename_chat(session=session, rename_object=chat)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return renamed
@router.get("/delete/{chart_id}")
async def delete(session: SessionDep, chart_id: int):
    """Handle ``GET /delete/{chart_id}`` by delegating to ``delete_chat``.

    Args:
        session: Injected database session.
        chart_id: Identifier taken from the URL path.

    Returns:
        The result of ``delete_chat``.

    Raises:
        HTTPException: Status 500 with the error text as detail when
            ``delete_chat`` raises anything.
    """
    try:
        outcome = delete_chat(session=session, chart_id=chart_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return outcome
@router.post("/start")
async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat):
    """Handle ``POST /start`` by delegating to ``create_chat``.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        create_chat_obj: Request body passed on to ``create_chat``.

    Returns:
        The result of ``create_chat``.

    Raises:
        HTTPException: Status 500 with the error text as detail when
            ``create_chat`` raises anything.
    """
    try:
        created = create_chat(session, current_user, create_chat_obj)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return created
# NOTE(review): this handler shares the function name ``start_chat`` with the
# ``POST /start`` handler defined earlier in this file, so it rebinds the
# module-level name. Consider renaming once importers have been checked.
@router.post("/assistant/start")
async def start_chat(session: SessionDep, current_user: CurrentUser):
    """Handle ``POST /assistant/start``.

    Calls ``create_chat`` with a fresh ``CreateChat(origin=2)`` and ``False``
    as the fourth positional argument.

    Args:
        session: Injected database session.
        current_user: Injected current user.

    Returns:
        The result of ``create_chat``.

    Raises:
        HTTPException: Status 500 with the error text as detail when building
            the request object or ``create_chat`` raises anything.
    """
    try:
        assistant_chat = CreateChat(origin=2)
        created = create_chat(session, current_user, assistant_chat, False)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return created
@router.post("/recommend_questions/{chat_record_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
                              current_assistant: CurrentAssistant):
    """Start the recommend-questions task for a chat record and stream its result.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        chat_record_id: Record identifier taken from the URL path.
        current_assistant: Injected current assistant.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` fed by
        ``llm_service.await_result()``.

    Raises:
        HTTPException: Status 400 when no record exists for
            ``chat_record_id``; status 500 (detail = error text) for any
            other failure while preparing the task.
    """
    try:
        record = get_chat_record_by_id(session, chat_record_id)
        if not record:
            raise HTTPException(
                status_code=400,
                detail=f"Chat record with id {chat_record_id} not found"
            )
        # A record without a question is sent to the service as an empty string.
        request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
        # NOTE(review): the meaning of the trailing ``True`` is defined by
        # LLMService.create, which is not visible here.
        llm_service = await LLMService.create(current_user, request_question, current_assistant, True)
        llm_service.set_record(record)
        llm_service.run_recommend_questions_task_async()
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
@router.post("/question")
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
                     current_assistant: CurrentAssistant):
    """Run a chat question through ``LLMService`` and stream the outcome.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        request_question: Request body carrying the user's question.
        current_assistant: Injected current assistant.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``. On success it is fed by
        ``llm_service.await_result()``; if setting up the service fails it
        carries a single ``data:`` event of type ``error`` instead.
    """

    def error_events(failure: Exception):
        # One event whose JSON payload holds the failure text.
        payload = orjson.dumps({'content': str(failure), 'type': 'error'}).decode()
        yield 'data:' + payload + '\n\n'

    try:
        llm_service = await LLMService.create(current_user, request_question, current_assistant)
        llm_service.init_record()
        llm_service.run_task_async()
    except Exception as exc:
        traceback.print_exc()
        return StreamingResponse(error_events(exc), media_type="text/event-stream")

    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
@router.post("/record/{chat_record_id}/{action_type}")
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
                              current_assistant: CurrentAssistant):
    """Start an analysis or predict task for a chat record and stream its result.

    Args:
        session: Injected database session.
        current_user: Injected current user.
        chat_record_id: Record identifier taken from the URL path.
        action_type: Path segment; only ``'analysis'`` and ``'predict'`` are accepted.
        current_assistant: Injected current assistant.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``. On success it is fed by
        ``llm_service.await_result()``; on any failure (unknown action type,
        missing record, record without a chart, service error) it carries a
        single ``data:`` event of type ``error`` instead.
    """

    def error_events(failure: Exception):
        # One event whose JSON payload holds the failure text.
        payload = orjson.dumps({'content': str(failure), 'type': 'error'}).decode()
        yield 'data:' + payload + '\n\n'

    # Only these columns are selected and copied onto the ChatRecord that is
    # handed to the task.
    wanted = ('id', 'question', 'chat_id', 'datasource', 'engine_type',
              'ai_modal_id', 'create_by', 'chart', 'data')

    try:
        if action_type not in ('analysis', 'predict'):
            raise Exception(f"Type {action_type} Not Found")

        record: ChatRecord | None = None
        stmt = select(*(getattr(ChatRecord, name) for name in wanted)).where(
            and_(ChatRecord.id == chat_record_id))
        for row in session.execute(stmt):
            # If several rows came back, the last one wins.
            record = ChatRecord(**{name: getattr(row, name) for name in wanted})

        if not record:
            raise Exception(f"Chat record with id {chat_record_id} not found")
        if not record.chart:
            raise Exception(
                f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it")

        request_question = ChatQuestion(chat_id=record.chat_id, question=record.question)
        llm_service = await LLMService.create(current_user, request_question, current_assistant)
        llm_service.run_analysis_or_predict_task_async(action_type, record)
    except Exception as exc:
        traceback.print_exc()
        return StreamingResponse(error_events(exc), media_type="text/event-stream")

    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
@router.post("/excel/export")
async def export_excel(excel_data: ExcelData):
    """Render the posted rows as an ``.xlsx`` workbook and return it.

    Each entry of ``excel_data.axis`` defines one column: ``field.name`` is
    the header and ``field.value`` is the key looked up in every item of
    ``excel_data.data``. The workbook is built in a worker thread.

    Args:
        excel_data: Request body providing ``axis`` (column definitions) and
            ``data`` (row items supporting ``.get``).

    Returns:
        A ``StreamingResponse`` with the spreadsheet media type whose body is
        the workbook bytes (single sheet ``Sheet1``, no index column).
    """

    def inner():
        headers = [field.name for field in excel_data.axis]
        rows = [[item.get(field.value) for field in excel_data.axis] for item in excel_data.data]

        if rows:
            df = pd.DataFrame(np.array(rows), columns=headers)
        else:
            # np.array([]) is 1-D, so passing it with several column names
            # fails with a shape mismatch; emit a header-only sheet instead.
            df = pd.DataFrame(columns=headers)

        buffer = io.BytesIO()
        # strings_to_numbers asks xlsxwriter to store numeric-looking strings
        # as numbers.
        with pd.ExcelWriter(buffer, engine='xlsxwriter',
                            engine_kwargs={'options': {'strings_to_numbers': True}}) as writer:
            df.to_excel(writer, sheet_name='Sheet1', index=False)
        # getvalue() returns the full contents regardless of stream position;
        # the copy is handed out as a fresh stream positioned at the start.
        return io.BytesIO(buffer.getvalue())

    result = await asyncio.to_thread(inner)
    return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")