from datetime import datetime from enum import Enum from typing import List, Optional from fastapi import Body from pydantic import BaseModel from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean from sqlalchemy import Enum as SQLAlchemyEnum from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field from apps.template.filter.generator import get_permissions_template from apps.template.generate_analysis.generator import get_analysis_template from apps.template.generate_chart.generator import get_chart_template from apps.template.generate_dynamic.generator import get_dynamic_template from apps.template.generate_guess_question.generator import get_guess_question_template from apps.template.generate_predict.generator import get_predict_template from apps.template.generate_sql.generator import get_sql_template from apps.template.select_datasource.generator import get_datasource_template def enum_values(enum_class: type[Enum]) -> list: """Get values for enum.""" return [status.value for status in enum_class] class TypeEnum(Enum): CHAT = "0" # TODO other usage class OperationEnum(Enum): GENERATE_SQL = '0' GENERATE_CHART = '1' ANALYSIS = '2' PREDICT_DATA = '3' GENERATE_RECOMMENDED_QUESTIONS = '4' GENERATE_SQL_WITH_PERMISSIONS = '5' CHOOSE_DATASOURCE = '6' GENERATE_DYNAMIC_SQL = '7' # TODO choose table / check connection / generate description class ChatLog(SQLModel, table=True): __tablename__ = "chat_log" id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) type: TypeEnum = Field( sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3))) operate: OperationEnum = Field( sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3))) pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True)) ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger)) base_modal: Optional[str] = Field(max_length=255) messages: Optional[list[dict]] = Field(sa_column=Column(JSONB)) reasoning_content: Optional[str | None] = Field(sa_column=Column(Text, nullable=True)) start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) token_usage: Optional[dict | None | int] = Field(sa_column=Column(JSONB)) class Chat(SQLModel, table=True): __tablename__ = "chat" id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1)) create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) create_by: int = Field(sa_column=Column(BigInteger, nullable=True)) brief: str = Field(max_length=64, nullable=True) chat_type: str = Field(max_length=20, default="chat") # chat, datasource datasource: int = Field(sa_column=Column(BigInteger, nullable=True)) engine_type: str = Field(max_length=64) origin: Optional[int] = Field( sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant class ChatRecord(SQLModel, table=True): __tablename__ = "chat_record" id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) chat_id: int = Field(sa_column=Column(BigInteger, nullable=False)) ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger)) first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False)) create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) create_by: int = Field(sa_column=Column(BigInteger, nullable=True)) datasource: int = Field(sa_column=Column(BigInteger, nullable=True)) engine_type: str = Field(max_length=64, nullable=True) question: str = Field(sa_column=Column(Text, nullable=True)) sql_answer: str = Field(sa_column=Column(Text, nullable=True)) sql: str = Field(sa_column=Column(Text, nullable=True)) sql_exec_result: str = Field(sa_column=Column(Text, nullable=True)) data: str = Field(sa_column=Column(Text, nullable=True)) chart_answer: str = Field(sa_column=Column(Text, nullable=True)) chart: str = Field(sa_column=Column(Text, nullable=True)) analysis: str = Field(sa_column=Column(Text, nullable=True)) predict: str = Field(sa_column=Column(Text, nullable=True)) predict_data: str = Field(sa_column=Column(Text, nullable=True)) recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True)) recommended_question: str = Field(sa_column=Column(Text, nullable=True)) datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True)) finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False)) error: str = Field(sa_column=Column(Text, nullable=True)) analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True)) predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True)) class ChatRecordResult(BaseModel): id: Optional[int] = None chat_id: Optional[int] = None ai_modal_id: Optional[int] = None first_chat: bool = False create_time: Optional[datetime] = None finish_time: Optional[datetime] = None question: Optional[str] = None sql_answer: Optional[str] = None sql: Optional[str] = None data: Optional[str] = None chart_answer: Optional[str] = None chart: Optional[str] = None analysis: Optional[str] = None predict: Optional[str] = None predict_data: Optional[str] = None recommended_question: Optional[str] = None datasource_select_answer: Optional[str] = None finish: Optional[bool] = None error: Optional[str] = None analysis_record_id: Optional[int] = None predict_record_id: Optional[int] = None sql_reasoning_content: Optional[str] = None chart_reasoning_content: Optional[str] = None analysis_reasoning_content: Optional[str] = None predict_reasoning_content: Optional[str] = None class CreateChat(BaseModel): id: int = None question: str = None datasource: int = None origin: Optional[int] = 0 class RenameChat(BaseModel): id: int = None brief: str = '' class ChatInfo(BaseModel): id: Optional[int] = None create_time: datetime = None create_by: int = None brief: str = '' chat_type: str = "chat" datasource: Optional[int] = None engine_type: str = '' ds_type: str = '' datasource_name: str = '' datasource_exists: bool = True records: List[ChatRecord | dict] = [] class AiModelQuestion(BaseModel): question: str = None ai_modal_id: int = None ai_modal_name: str = None # Specific model name engine: str = "" db_schema: str = "" sql: str = "" rule: str = "" fields: str = "" data: str = "" lang: str = "简体中文" filter: str = [] sub_query: Optional[list[dict]] = None terminologies: str = "" error_msg: str = "" def sql_sys_question(self): return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, lang=self.lang, terminologies=self.terminologies) def sql_user_question(self, current_time: str): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, rule=self.rule, current_time=current_time, error_msg=self.error_msg) def chart_sys_question(self): return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) def chart_user_question(self, chart_type: Optional[str] = None): return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule, chart_type=chart_type) def analysis_sys_question(self): return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies) def analysis_user_question(self): return get_analysis_template()['user'].format(fields=self.fields, data=self.data) def predict_sys_question(self): return get_predict_template()['system'].format(lang=self.lang) def predict_user_question(self): return get_predict_template()['user'].format(fields=self.fields, data=self.data) def datasource_sys_question(self): return get_datasource_template()['system'].format(lang=self.lang) def datasource_user_question(self, datasource_list: str = "[]"): return get_datasource_template()['user'].format(question=self.question, data=datasource_list) def guess_sys_question(self): return get_guess_question_template()['system'].format(lang=self.lang) def guess_user_question(self, old_questions: str = "[]"): return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema, old_questions=old_questions) def filter_sys_question(self): return get_permissions_template()['system'].format(lang=self.lang, engine=self.engine) def filter_user_question(self): return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter) def dynamic_sys_question(self): return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine) def dynamic_user_question(self): return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query) class ChatQuestion(AiModelQuestion): chat_id: int class ChatMcp(ChatQuestion): token: str class ChatStart(BaseModel): username: str = Body(description='用户名') password: str = Body(description='密码') class McpQuestion(BaseModel): question: str = Body(description='用户提问') chat_id: int = Body(description='会话ID') token: str = Body(description='token') class AxisObj(BaseModel): name: str = '' value: str = '' type: str | None = None class ExcelData(BaseModel): axis: list[AxisObj] = [] data: list[dict] = [] name: str = 'Excel'