diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py new file mode 100644 index 0000000..fa85d64 --- /dev/null +++ b/backend/apps/chat/api/chat.py @@ -0,0 +1,225 @@ +import asyncio +import io +import traceback + +import numpy as np +import orjson +import pandas as pd +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy import and_, select + +from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \ + delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id +from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData +from apps.chat.task.llm import LLMService +from common.core.deps import CurrentAssistant, SessionDep, CurrentUser + +router = APIRouter(tags=["Data Q&A"], prefix="/chat") + + +@router.get("/list") +async def chats(session: SessionDep, current_user: CurrentUser): + return list_chats(session, current_user) + + +@router.get("/get/{chart_id}") +async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant): + def inner(): + return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user, + current_assistant=current_assistant) + + return await asyncio.to_thread(inner) + + +@router.get("/get/with_data/{chart_id}") +async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int, + current_assistant: CurrentAssistant): + def inner(): + return get_chat_with_records_with_data(chart_id=chart_id, session=session, current_user=current_user, + current_assistant=current_assistant) + + return await asyncio.to_thread(inner) + + +@router.get("/record/get/{chart_record_id}/data") +async def chat_record_data(session: SessionDep, chart_record_id: int): + def inner(): + return get_chat_chart_data(chart_record_id=chart_record_id, session=session) + + return await asyncio.to_thread(inner) + + +@router.get("/record/get/{chart_record_id}/predict_data") +async def chat_predict_data(session: SessionDep, chart_record_id: int): + def inner(): + return get_chat_predict_data(chart_record_id=chart_record_id, session=session) + + return await asyncio.to_thread(inner) + + +@router.post("/rename") +async def rename(session: SessionDep, chat: RenameChat): + try: + return rename_chat(session=session, rename_object=chat) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +@router.get("/delete/{chart_id}") +async def delete(session: SessionDep, chart_id: int): + try: + return delete_chat(session=session, chart_id=chart_id) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +@router.post("/start") +async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat): + try: + return create_chat(session, current_user, create_chat_obj) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +@router.post("/assistant/start") +async def start_chat(session: SessionDep, current_user: CurrentUser): + try: + return create_chat(session, current_user, CreateChat(origin=2), False) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e) + ) + + +@router.post("/recommend_questions/{chat_record_id}") +async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, + current_assistant: CurrentAssistant): + try: + record = get_chat_record_by_id(session, chat_record_id) + + if not record: + raise HTTPException( + status_code=400, + detail=f"Chat record with id {chat_record_id} not found" + ) + request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '') + + llm_service = await LLMService.create(current_user, request_question, current_assistant, True) + llm_service.set_record(record) + llm_service.run_recommend_questions_task_async() + except Exception as e: + traceback.print_exc() + raise HTTPException( + status_code=500, + detail=str(e) + ) + + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + + +@router.post("/question") +async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, + current_assistant: CurrentAssistant): + """Stream SQL analysis results + + Args: + session: Database session + current_user: CurrentUser + request_question: User question model + + Returns: + Streaming response with analysis results + """ + + try: + llm_service = await LLMService.create(current_user, request_question, current_assistant) + llm_service.init_record() + llm_service.run_task_async() + except Exception as e: + traceback.print_exc() + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + + +@router.post("/record/{chat_record_id}/{action_type}") +async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str, + current_assistant: CurrentAssistant): + try: + if action_type != 'analysis' and action_type != 'predict': + raise Exception(f"Type {action_type} Not Found") + record: ChatRecord | None = None + + stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, + ChatRecord.engine_type, + ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart, ChatRecord.data).where( + and_(ChatRecord.id == chat_record_id)) + result = session.execute(stmt) + for r in result: + record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, + engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by, + chart=r.chart, + data=r.data) + + if not record: + raise Exception(f"Chat record with id {chat_record_id} not found") + + if not record.chart: + raise Exception( + f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it") + + request_question = ChatQuestion(chat_id=record.chat_id, question=record.question) + + llm_service = await LLMService.create(current_user, request_question, current_assistant) + llm_service.run_analysis_or_predict_task_async(action_type, record) + except Exception as e: + traceback.print_exc() + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + + +@router.post("/excel/export") +async def export_excel(excel_data: ExcelData): + def inner(): + _fields_list = [] + data = [] + for _data in excel_data.data: + _row = [] + for field in excel_data.axis: + _row.append(_data.get(field.value)) + data.append(_row) + for field in excel_data.axis: + _fields_list.append(field.name) + df = pd.DataFrame(np.array(data), columns=_fields_list) + + buffer = io.BytesIO() + + with pd.ExcelWriter(buffer, engine='xlsxwriter', + engine_kwargs={'options': {'strings_to_numbers': True}}) as writer: + df.to_excel(writer, sheet_name='Sheet1', index=False) + + buffer.seek(0) + return io.BytesIO(buffer.getvalue()) + + result = await asyncio.to_thread(inner) + return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")