import datetime
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from xml.dom.minidom import parseString

import dicttoxml
from sqlalchemy import and_, or_, select, func, delete, update, union
from sqlalchemy import create_engine, text
from sqlalchemy.orm import aliased
from sqlalchemy.orm import sessionmaker

from apps.ai_model.embedding import EmbeddingModelCache
from apps.template.generate_chart.generator import get_base_terminology_template
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
from common.core.config import settings
from common.core.deps import SessionDep, Trans

# Shared pool for background embedding jobs
executor = ThreadPoolExecutor(max_workers=200)


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
                     oid: Optional[int] = 1):
    _list: List[TerminologyInfo] = []
    child = aliased(Terminology)

    current_page = max(1, current_page)
    page_size = max(10, page_size)

    total_count = 0
    total_pages = 0

    if name and name.strip() != "":
        keyword_pattern = f"%{name.strip()}%"

        # Step 1: find the ids of all matching nodes (parents as well as children)
        matched_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))  # LIKE condition
            .subquery()
        )

        # Step 2: find the parents of those matches (a match that is itself a parent is included)
        parent_ids_subquery = (
            select(Terminology.id)
            .where(
                (Terminology.id.in_(matched_ids_subquery)) |
                (Terminology.id.in_(
                    select(Terminology.pid)
                    .where(Terminology.id.in_(matched_ids_subquery))
                    .where(Terminology.pid.isnot(None))
                ))
            )
            .where(Terminology.pid.is_(None))  # parent nodes only
        )

        count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
        total_count = session.execute(count_stmt).scalar()
        total_pages = (total_count + page_size - 1) // page_size

        if current_page > total_pages:
            current_page = 1

        # Step 3: page through the parent ids
        paginated_parent_ids = (
            parent_ids_subquery
            .order_by(Terminology.create_time.desc())
            .offset((current_page - 1) * page_size)
            .limit(page_size)
            .subquery()
        )

        # Step 4: aggregate the children words (other_words) for those parents
        children_subquery = (
            select(
                child.pid,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
            )
            .where(child.pid.isnot(None))
            .group_by(child.pid)
            .subquery()
        )

        # main query
        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                children_subquery.c.other_words
            )
            .outerjoin(
                children_subquery, Terminology.id == children_subquery.c.pid
            )
            .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
            .order_by(Terminology.create_time.desc())
        )
    else:
        parent_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.pid.is_(None), Terminology.oid == oid))  # parent nodes only
        )
        count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
        total_count = session.execute(count_stmt).scalar()
        total_pages = (total_count + page_size - 1) // page_size

        if current_page > total_pages:
            current_page = 1

        paginated_parent_ids = (
            parent_ids_subquery
            .order_by(Terminology.create_time.desc())
            .offset((current_page - 1) * page_size)
            .limit(page_size)
            .subquery()
        )

        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
            )
            .outerjoin(child, and_(Terminology.id == child.pid))
            .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
            .group_by(Terminology.id, Terminology.word)
            .order_by(Terminology.create_time.desc())
        )

    result = session.execute(stmt)

    for row in result:
        _list.append(TerminologyInfo(
            id=row.id,
            word=row.word,
            create_time=row.create_time,
            description=row.description,
            other_words=row.other_words if row.other_words else [],
        ))

    return current_page, page_size, total_count, total_pages, _list


def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    create_time = datetime.datetime.now()
    parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)

    words = [info.word]
    for child in info.other_words:
        if child in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        else:
            words.append(child)

    exists = session.query(
        session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()
    ).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    result = Terminology(**parent.model_dump())

    session.add(parent)
    session.flush()
    session.refresh(parent)
    result.id = parent.id
    session.commit()

    _list: List[Terminology] = []
    if info.other_words:
        for other_word in info.other_words:
            if other_word.strip() == "":
                continue
            _list.append(
                Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid))
        session.bulk_save_objects(_list)
        session.flush()
        session.commit()

    # embedding
    run_save_embeddings([result.id])

    return result.id


def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    count = session.query(Terminology).filter(
        Terminology.oid == oid,
        Terminology.id == info.id
    ).count()
    if count == 0:
        raise Exception(trans('i18n_terminology.terminology_not_exists'))

    words = [info.word]
    for child in info.other_words:
        if child in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        else:
            words.append(child)

    exists = session.query(
        session.query(Terminology).filter(
            Terminology.word.in_(words),
            Terminology.oid == oid,
            or_(
                Terminology.pid != info.id,
                and_(Terminology.pid.is_(None), Terminology.id != info.id)
            ),
            Terminology.id != info.id
        ).exists()
    ).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
        word=info.word,
        description=info.description,
    )
    session.execute(stmt)
    session.commit()

    stmt = delete(Terminology).where(and_(Terminology.pid == info.id))
    session.execute(stmt)
    session.commit()

    create_time = datetime.datetime.now()
    _list: List[Terminology] = []
    if info.other_words:
        for other_word in info.other_words:
            if other_word.strip() == "":
                continue
            _list.append(
                Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid))
        session.bulk_save_objects(_list)
        session.flush()
        session.commit()

    # embedding
    run_save_embeddings([info.id])

    return info.id


def delete_terminology(session: SessionDep, ids: list[int]):
    stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
    session.execute(stmt)
    session.commit()


def run_save_embeddings(ids: List[int]):
    executor.submit(save_embeddings, ids)


def fill_empty_embeddings():
    executor.submit(run_fill_empty_embeddings)


def run_fill_empty_embeddings():
    if not settings.EMBEDDING_ENABLED:
        return
    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    session_maker = sessionmaker(bind=engine)
    session = session_maker()

    stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
    stmt2 = select(Terminology.pid).where(
        and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
    combined_stmt = union(stmt1, stmt2)
    results = session.execute(combined_stmt).scalars().all()
    save_embeddings(results)


def save_embeddings(ids: List[int]):
    if not settings.EMBEDDING_ENABLED:
        return
    if not ids or len(ids) == 0:
        return
    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        session_maker = sessionmaker(bind=engine)
        session = session_maker()

        _list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
        _words_list = [item.word for item in _list]

        model = EmbeddingModelCache.get_model()
        results = model.embed_documents(_words_list)

        for index in range(len(results)):
            item = results[index]
            stmt = update(Terminology).where(and_(Terminology.id == _list[index].id)).values(embedding=item)
            session.execute(stmt)
        session.commit()
    except Exception:
        traceback.print_exc()


# pgvector similarity query: <=> is cosine distance, so 1 - distance is similarity
embedding_sql = f"""
SELECT id, pid, word, similarity
FROM (SELECT id, pid, word, oid,
             (1 - (embedding <=> :embedding_array)) AS similarity
      FROM terminology AS child) TEMP
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TOP_COUNT}
"""


def select_terminology_by_word(session: SessionDep, word: str, oid: int):
    if word.strip() == "":
        return []

    _list: List[Terminology] = []

    # Exact match: keep every terminology word that appears as a substring of the question
    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
        )
        .where(
            and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid)
        )
    )
    results = session.execute(stmt, {'sentence': word}).fetchall()
    for row in results:
        _list.append(Terminology(id=row.id, word=row.word, pid=row.pid))

    # Semantic match via embeddings, when enabled
    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()
            embedding = model.embed_query(word)

            results = session.execute(text(embedding_sql),
                                      {'embedding_array': str(embedding), 'oid': oid})
            for row in results:
                _list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
        except Exception:
            traceback.print_exc()

    # Deduplicate to parent ids
    _map: dict = {}
    _ids: list[int] = []
    for row in _list:
        if row.id in _ids or row.pid in _ids:
            continue
        if row.pid is not None:
            _ids.append(row.pid)
        else:
            _ids.append(row.id)

    if len(_ids) == 0:
        return []

    t_list = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
        or_(Terminology.id.in_(_ids), Terminology.pid.in_(_ids))).all()

    # Group each parent with its synonym words and description
    for row in t_list:
        pid = str(row.pid) if row.pid is not None else str(row.id)
        if _map.get(pid) is None:
            _map[pid] = {'words': [], 'description': row.description}
        _map[pid]['words'].append(row.word)

    _results: list[dict] = []
    for key in _map.keys():
        _results.append(_map.get(key))

    return _results


def get_example():
    _obj = {
        'terminologies': [
            {'words': ['GDP', '国内生产总值'],
             'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'},
        ]
    }
    return to_xml_string(_obj, 'example')


def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    item_name_func = lambda x: 'terminology' if x == 'terminologies' else 'word' if x == 'words' else 'item'
    dicttoxml.LOG.setLevel(logging.ERROR)
    xml = dicttoxml.dicttoxml(_dict,
                              cdata=['word', 'description'],
                              custom_root=root,
                              item_func=item_name_func,
                              xml_declaration=False,
                              encoding='utf-8',
                              attr_type=False).decode('utf-8')
    pretty_xml = parseString(xml).toprettyxml()

    # Strip the XML declaration that toprettyxml() prepends
    if pretty_xml.startswith('<?xml'):
        end_index = pretty_xml.find('>') + 1
        pretty_xml = pretty_xml[end_index:].lstrip()

    # Unescape all XML entities
    escape_map = {
        '&lt;': '<',
        '&gt;': '>',
        '&amp;': '&',
        '&quot;': '"',
        '&apos;': "'"
    }
    for escaped, original in escape_map.items():
        pretty_xml = pretty_xml.replace(escaped, original)

    return pretty_xml


def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
    if not oid:
        oid = 1
    _results = select_terminology_by_word(session, question, oid)
    if _results and len(_results) > 0:
        terminology = to_xml_string(_results)
        template = get_base_terminology_template().format(terminologies=terminology)
        return template
    else:
        return ''
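

# Minimal usage sketch (illustrative only, not part of the service wiring): builds a
# throwaway session the same way the background embedding jobs above do and renders
# the terminology prompt fragment for a sample question. Assumes the configured
# database and embedding settings are reachable.
if __name__ == '__main__':
    _engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    _session = sessionmaker(bind=_engine)()
    try:
        # Returns '' when no terminology matches the question
        print(get_terminology_template(_session, 'GDP', oid=1))
    finally:
        _session.close()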