Add File
This commit is contained in:
236
backend/apps/system/crud/assistant.py
Normal file
236
backend/apps/system/crud/assistant.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import json
|
||||
import urllib
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlmodel import Session, select
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
|
||||
from apps.system.models.system_model import AssistantModel
|
||||
from apps.system.schemas.auth import CacheName, CacheNamespace
|
||||
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
|
||||
from common.core.config import settings
|
||||
from common.core.db import engine
|
||||
from common.core.sqlbot_cache import cache
|
||||
from common.utils.aes_crypto import simple_aes_decrypt
|
||||
from common.utils.utils import string_to_numeric_hash
|
||||
|
||||
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
|
||||
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
|
||||
db_model = session.get(AssistantModel, assistant_id)
|
||||
return db_model
|
||||
|
||||
|
||||
def get_assistant_user(*, id: int):
|
||||
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant",
|
||||
email="sqlbot-inner-assistant@sqlbot.com")
|
||||
|
||||
|
||||
def get_assistant_ds(session: Session, llm_service) -> list[dict]:
|
||||
assistant: AssistantHeader = llm_service.current_assistant
|
||||
type = assistant.type
|
||||
if type == 0 or type == 2:
|
||||
configuration = assistant.configuration
|
||||
if configuration:
|
||||
config: dict[any] = json.loads(configuration)
|
||||
oid: int = int(config['oid'])
|
||||
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
|
||||
CoreDatasource.oid == oid)
|
||||
if not assistant.online:
|
||||
public_list: list[int] = config.get('public_list') or None
|
||||
if public_list:
|
||||
stmt = stmt.where(CoreDatasource.id.in_(public_list))
|
||||
else:
|
||||
return []
|
||||
""" private_list: list[int] = config.get('private_list') or None
|
||||
if private_list:
|
||||
stmt = stmt.where(~CoreDatasource.id.in_(private_list)) """
|
||||
db_ds_list = session.exec(stmt)
|
||||
|
||||
result_list = [
|
||||
{
|
||||
"id": ds.id,
|
||||
"name": ds.name,
|
||||
"description": ds.description
|
||||
}
|
||||
for ds in db_ds_list
|
||||
]
|
||||
|
||||
# filter private ds if offline
|
||||
return result_list
|
||||
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
|
||||
llm_service.out_ds_instance = out_ds_instance
|
||||
dslist = out_ds_instance.get_simple_ds_list()
|
||||
# format?
|
||||
return dslist
|
||||
|
||||
|
||||
def init_dynamic_cors(app: FastAPI):
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()
|
||||
seen = set()
|
||||
unique_domains = []
|
||||
for item in list_result:
|
||||
if item.domain:
|
||||
for domain in item.domain.split(','):
|
||||
domain = domain.strip()
|
||||
if domain and domain not in seen:
|
||||
seen.add(domain)
|
||||
unique_domains.append(domain)
|
||||
cors_middleware = None
|
||||
for middleware in app.user_middleware:
|
||||
if middleware.cls == CORSMiddleware:
|
||||
cors_middleware = middleware
|
||||
break
|
||||
if cors_middleware:
|
||||
updated_origins = list(set(settings.all_cors_origins + unique_domains))
|
||||
cors_middleware.kwargs['allow_origins'] = updated_origins
|
||||
except Exception as e:
|
||||
return False, e
|
||||
|
||||
|
||||
class AssistantOutDs:
|
||||
assistant: AssistantHeader
|
||||
ds_list: Optional[list[AssistantOutDsSchema]] = None
|
||||
certificate: Optional[str] = None
|
||||
|
||||
def __init__(self, assistant: AssistantHeader):
|
||||
self.assistant = assistant
|
||||
self.ds_list = None
|
||||
self.certificate = assistant.certificate
|
||||
self.get_ds_from_api()
|
||||
|
||||
# @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
|
||||
def get_ds_from_api(self):
|
||||
config: dict[any] = json.loads(self.assistant.configuration)
|
||||
endpoint: str = config['endpoint']
|
||||
certificateList: list[any] = json.loads(self.certificate)
|
||||
header = {}
|
||||
cookies = {}
|
||||
param = {}
|
||||
for item in certificateList:
|
||||
if item['target'] == 'header':
|
||||
header[item['key']] = item['value']
|
||||
if item['target'] == 'cookie':
|
||||
cookies[item['key']] = item['value']
|
||||
if item['target'] == 'param':
|
||||
param[item['key']] = item['value']
|
||||
res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=10)
|
||||
if res.status_code == 200:
|
||||
result_json: dict[any] = json.loads(res.text)
|
||||
if result_json.get('code') == 0 or result_json.get('code') == 200:
|
||||
temp_list = result_json.get('data', [])
|
||||
temp_ds_list = [
|
||||
self.convert2schema(item, config)
|
||||
for item in temp_list
|
||||
]
|
||||
self.ds_list = temp_ds_list
|
||||
return self.ds_list
|
||||
else:
|
||||
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
|
||||
else:
|
||||
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")
|
||||
|
||||
def get_simple_ds_list(self):
|
||||
if self.ds_list:
|
||||
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
|
||||
else:
|
||||
raise Exception("Datasource list is not found.")
|
||||
|
||||
def get_db_schema(self, ds_id: int) -> str:
|
||||
ds = self.get_ds(ds_id)
|
||||
schema_str = ""
|
||||
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
|
||||
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
|
||||
for table in ds.tables:
|
||||
schema_str += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
|
||||
table_comment = table.comment
|
||||
if table_comment == '':
|
||||
schema_str += '\n[\n'
|
||||
else:
|
||||
schema_str += f", {table_comment}\n[\n"
|
||||
|
||||
field_list = []
|
||||
for field in table.fields:
|
||||
field_comment = field.comment
|
||||
if field_comment == '':
|
||||
field_list.append(f"({field.name}:{field.type})")
|
||||
else:
|
||||
field_list.append(f"({field.name}:{field.type}, {field_comment})")
|
||||
schema_str += ",\n".join(field_list)
|
||||
schema_str += '\n]\n'
|
||||
return schema_str
|
||||
|
||||
def get_ds(self, ds_id: int):
|
||||
if self.ds_list:
|
||||
for ds in self.ds_list:
|
||||
if ds.id == ds_id:
|
||||
return ds
|
||||
else:
|
||||
raise Exception("Datasource list is not found.")
|
||||
raise Exception(f"Datasource with id {ds_id} not found.")
|
||||
|
||||
def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
|
||||
id_marker: str = ''
|
||||
attr_list = ['name', 'type', 'host', 'port', 'user', 'dataBase', 'schema']
|
||||
if config.get('encrypt', False):
|
||||
key = config.get('aes_key', None)
|
||||
iv = config.get('aes_iv', None)
|
||||
aes_attrs = ['host', 'user', 'password', 'dataBase', 'db_schema']
|
||||
for attr in aes_attrs:
|
||||
if attr in ds_dict and ds_dict[attr]:
|
||||
try:
|
||||
ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
|
||||
for attr in attr_list:
|
||||
if attr in ds_dict:
|
||||
id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--'
|
||||
id = string_to_numeric_hash(id_marker)
|
||||
db_schema = ds_dict.get('schema', ds_dict.get('db_schema', ''))
|
||||
ds_dict.pop("schema", None)
|
||||
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})
|
||||
|
||||
class AssistantOutDsFactory:
|
||||
@staticmethod
|
||||
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
|
||||
return AssistantOutDs(assistant)
|
||||
|
||||
|
||||
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
|
||||
timeout: int = 30
|
||||
connect_args = {"connect_timeout": timeout}
|
||||
conf = DatasourceConf(
|
||||
host=ds.host,
|
||||
port=ds.port,
|
||||
username=ds.user,
|
||||
password=ds.password,
|
||||
database=ds.dataBase,
|
||||
driver='',
|
||||
extraJdbc=ds.extraParams,
|
||||
dbSchema=ds.db_schema or ''
|
||||
)
|
||||
conf.extraJdbc = ''
|
||||
from apps.db.db import get_uri_from_config
|
||||
uri = get_uri_from_config(ds.type, conf)
|
||||
# if ds.type == "pg" and ds.db_schema:
|
||||
# connect_args.update({"options": f"-c search_path={ds.db_schema}"})
|
||||
# engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
|
||||
|
||||
if ds.type == "pg" and ds.db_schema:
|
||||
engine = create_engine(uri,
|
||||
connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}",
|
||||
"connect_timeout": timeout},
|
||||
pool_timeout=timeout)
|
||||
elif ds.type == 'sqlServer':
|
||||
engine = create_engine(uri, pool_timeout=timeout)
|
||||
elif ds.type == 'oracle':
|
||||
engine = create_engine(uri,
|
||||
pool_timeout=timeout)
|
||||
else:
|
||||
engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout)
|
||||
return engine
|
||||
Reference in New Issue
Block a user