import base64 import json import platform import urllib.parse from decimal import Decimal from typing import Optional from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql from common.error import ParseSQLResultError if platform.system() != "Darwin": import dmPython import pymysql import redshift_connector from sqlalchemy import create_engine, text, Engine from sqlalchemy.orm import sessionmaker from apps.datasource.models.datasource import DatasourceConf, CoreDatasource, TableSchema, ColumnSchema from apps.datasource.utils.utils import aes_decrypt from apps.db.constant import DB, ConnectType from apps.db.engine import get_engine_config from apps.system.crud.assistant import get_ds_engine from apps.system.schemas.system_schema import AssistantOutDsSchema from common.core.deps import Trans from common.utils.utils import SQLBotLogUtil from fastapi import HTTPException def get_uri(ds: CoreDatasource) -> str: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() return get_uri_from_config(ds.type, conf) def get_uri_from_config(type: str, conf: DatasourceConf) -> str: db_url: str if type == "mysql": if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" elif type == "sqlServer": if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" elif type == "pg" or type == "excel": if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" elif type == "oracle": if conf.mode == "service_name": if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}&{conf.extraJdbc}" else: db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}" else: if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" elif type == "ck": if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" else: db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" else: raise 'The datasource type not support.' return db_url def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() if conf.timeout is None: conf.timeout = timeout if timeout > 0: conf.timeout = timeout if ds.type == "pg": if conf.dbSchema is not None and conf.dbSchema != "": engine = create_engine(get_uri(ds), connect_args={"options": f"-c search_path={urllib.parse.quote(conf.dbSchema)}", "connect_timeout": conf.timeout}, pool_timeout=conf.timeout) else: engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) elif ds.type == 'sqlServer': engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) elif ds.type == 'oracle': engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) else: # mysql, ck engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) return engine def get_session(ds: CoreDatasource | AssistantOutDsSchema): engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds) session_maker = sessionmaker(bind=engine) session = session_maker() return session def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDsSchema, is_raise: bool = False): if isinstance(ds, CoreDatasource): db = DB.get_db(ds.type) if db.connect_type == ConnectType.sqlalchemy: conn = get_engine(ds, 10) try: with conn.connect() as connection: SQLBotLogUtil.info("success") return True except Exception as e: SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: try: cursor.execute('select 1', timeout=10).fetchall() SQLBotLogUtil.info("success") return True except Exception as e: SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, read_timeout=10) as conn, conn.cursor() as cursor: try: cursor.execute('select 1') SQLBotLogUtil.info("success") return True except Exception as e: SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=10) as conn, conn.cursor() as cursor: try: cursor.execute('select 1') SQLBotLogUtil.info("success") return True except Exception as e: SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False else: conn = get_ds_engine(ds) try: with conn.connect() as connection: SQLBotLogUtil.info("success") return True except Exception as e: SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") if is_raise: raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') return False return False def get_version(ds: CoreDatasource | AssistantOutDsSchema): version = '' conf = None if isinstance(ds, CoreDatasource): conf = DatasourceConf( **json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() if isinstance(ds, AssistantOutDsSchema): conf = DatasourceConf() conf.host = ds.host conf.port = ds.port conf.username = ds.user conf.password = ds.password conf.database = ds.dataBase conf.dbSchema = ds.db_schema conf.timeout = 10 db = DB.get_db(ds.type) sql = get_version_sql(ds, conf) try: if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql)) as result: res = result.fetchall() version = res[0][0] else: if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: cursor.execute(sql, timeout=10) res = cursor.fetchall() version = res[0][0] elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, read_timeout=10) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() version = res[0][0] elif ds.type == 'redshift': version = '' except Exception as e: print(e) version = '' return version.decode() if isinstance(version, bytes) else version def get_schema(ds: CoreDatasource): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: sql: str = '' if ds.type == "sqlServer": sql = f"""select name from sys.schemas""" elif ds.type == "pg" or ds.type == "excel": sql = """SELECT nspname FROM pg_namespace""" elif ds.type == "oracle": sql = f"""select * from all_users""" with session.execute(text(sql)) as result: res = result.fetchall() res_list = [item[0] for item in res] return res_list else: if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: cursor.execute(f"""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) res = cursor.fetchall() res_list = [item[0] for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout) as conn, conn.cursor() as cursor: cursor.execute(f"""SELECT nspname FROM pg_namespace""") res = cursor.fetchall() res_list = [item[0] for item in res] return res_list def get_tables(ds: CoreDatasource): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) sql = get_table_sql(ds, conf) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql)) as result: res = result.fetchall() res_list = [TableSchema(*item) for item in res] return res_list else: if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: cursor.execute(sql, timeout=conf.timeout) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list def get_fields(ds: CoreDatasource, table_name: str = None): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) sql = get_field_sql(ds, conf, table_name) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql)) as result: res = result.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list else: if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: cursor.execute(sql, timeout=conf.timeout) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False): while sql.endswith(';'): sql = sql[:-1] db = DB.get_db(ds.type) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: with session.execute(text(sql)) as result: try: columns = result.keys()._keys if origin_column else [item.lower() for item in result.keys()._keys] res = result.fetchall() result_list = [ {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in enumerate(tuple_item)} for tuple_item in res ] return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: try: cursor.execute(sql, timeout=conf.timeout) res = cursor.fetchall() columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for field in cursor.description] result_list = [ {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in enumerate(tuple_item)} for tuple_item in res ] return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, read_timeout=conf.timeout) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for field in cursor.description] result_list = [ {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in enumerate(tuple_item)} for tuple_item in res ] return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex)) elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, timeout=conf.timeout) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for field in cursor.description] result_list = [ {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in enumerate(tuple_item)} for tuple_item in res ] return {"fields": columns, "data": result_list, "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise ParseSQLResultError(str(ex))