import datetime
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import List, Optional
from xml.dom.minidom import parseString

import dicttoxml
from sqlalchemy import and_, or_, select, func, delete, update, union
from sqlalchemy import create_engine, text
from sqlalchemy.orm import aliased
from sqlalchemy.orm import sessionmaker

from apps.ai_model.embedding import EmbeddingModelCache
from apps.template.generate_chart.generator import get_base_terminology_template
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
from common.core.config import settings
from common.core.deps import SessionDep, Trans

# Shared pool used to run embedding jobs without blocking the caller.
executor = ThreadPoolExecutor(max_workers=200)


def _paginate_parent_ids(session: SessionDep, parent_ids_query, current_page: int, page_size: int):
    """Count the rows of ``parent_ids_query`` and build the subquery for one page.

    Returns ``(current_page, total_count, total_pages, paginated_subquery)``.
    ``current_page`` falls back to 1 when it points past the last page.
    """
    count_stmt = select(func.count()).select_from(parent_ids_query.subquery())
    total_count = session.execute(count_stmt).scalar() or 0
    total_pages = (total_count + page_size - 1) // page_size

    if current_page > total_pages:
        current_page = 1

    paginated_parent_ids = (
        parent_ids_query
        .order_by(Terminology.create_time.desc())
        .offset((current_page - 1) * page_size)
        .limit(page_size)
        .subquery()
    )
    return current_page, total_count, total_pages, paginated_parent_ids


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
                     oid: Optional[int] = 1):
    """Return one page of parent terms (rows whose ``pid`` is NULL) for ``oid``.

    When ``name`` is non-blank, a parent is included if its own word or the
    word of one of its children matches ``name`` (case-insensitive substring).

    Args:
        session: database session.
        current_page: 1-based page number; values below 1 become 1, and a
            page past the end falls back to 1.
        page_size: rows per page; values below 10 are raised to 10.
        name: optional search keyword.
        oid: value matched against ``Terminology.oid``.

    Returns:
        ``(current_page, page_size, total_count, total_pages, items)`` where
        ``items`` is a list of ``TerminologyInfo`` ordered by ``create_time``
        descending, each carrying its children's words in ``other_words``.
    """
    current_page = max(1, current_page)
    page_size = max(10, page_size)

    keyword = name.strip() if name else ""
    if keyword:
        keyword_pattern = f"%{keyword}%"
        # Step 1: ids of every matching row, parent or child.
        matched_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))
            .subquery()
        )

        # Step 2: the parents of those rows (a matching parent is its own parent).
        parent_ids_query = (
            select(Terminology.id)
            .where(
                (Terminology.id.in_(matched_ids_subquery)) |
                (Terminology.id.in_(
                    select(Terminology.pid)
                    .where(Terminology.id.in_(matched_ids_subquery))
                    .where(Terminology.pid.isnot(None))
                ))
            )
            .where(Terminology.pid.is_(None))  # parents only
            # Keep the count consistent with the oid filter of the main query.
            .where(Terminology.oid == oid)
        )
    else:
        parent_ids_query = (
            select(Terminology.id)
            .where(and_(Terminology.pid.is_(None), Terminology.oid == oid))  # parents only
        )

    # Step 3: count and select the parent ids of the requested page.
    current_page, total_count, total_pages, paginated_parent_ids = _paginate_parent_ids(
        session, parent_ids_query, current_page, page_size)

    # Step 4: aggregate the words of the children per parent.
    child = aliased(Terminology)
    children_subquery = (
        select(
            child.pid,
            func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
        )
        .where(child.pid.isnot(None))
        .group_by(child.pid)
        .subquery()
    )

    # Main query: parents of the page joined with their aggregated children.
    stmt = (
        select(
            Terminology.id,
            Terminology.word,
            Terminology.create_time,
            Terminology.description,
            children_subquery.c.other_words
        )
        .outerjoin(
            children_subquery,
            Terminology.id == children_subquery.c.pid
        )
        .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
        .order_by(Terminology.create_time.desc())
    )

    items: List[TerminologyInfo] = [
        TerminologyInfo(
            id=row.id,
            word=row.word,
            create_time=row.create_time,
            description=row.description,
            other_words=row.other_words or [],
        )
        for row in session.execute(stmt)
    ]

    return current_page, page_size, total_count, total_pages, items


def _collect_words(info: TerminologyInfo, trans: Trans):
    """Return ``(all_words, other_words)`` for ``info``.

    Blank entries of ``info.other_words`` are dropped (they are never stored).
    Raises ``Exception`` when a word appears more than once.
    """
    other_words = [word for word in (info.other_words or []) if word and word.strip() != ""]

    words = [info.word]
    for word in other_words:
        if word in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        words.append(word)
    return words, other_words


def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Create a parent term plus one child row per non-blank entry of ``info.other_words``.

    Raises:
        Exception: when the submitted words contain a duplicate, or when any
            of them already exists for ``oid``.

    Returns:
        The id of the new parent row.
    """
    words, other_words = _collect_words(info, trans)

    exists = session.query(
        session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()
    ).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    create_time = datetime.datetime.now()
    parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)
    session.add(parent)
    # Flush so that the generated id is available for the child rows.
    session.flush()
    parent_id = parent.id

    session.bulk_save_objects([
        Terminology(pid=parent_id, word=other_word, create_time=create_time, oid=oid)
        for other_word in other_words
    ])
    # Parent and children are committed together so a failure cannot leave a
    # parent without its aliases.
    session.commit()

    # embedding
    run_save_embeddings([parent_id])

    return parent_id


def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Update the parent term ``info.id`` and replace all of its child rows.

    Raises:
        Exception: when no row with ``info.id`` exists for ``oid``, when the
            submitted words contain a duplicate, or when any of them is
            already used by a row outside this term.

    Returns:
        ``info.id``.
    """
    count = session.query(Terminology).filter(
        Terminology.oid == oid,
        Terminology.id == info.id
    ).count()
    if count == 0:
        raise Exception(trans('i18n_terminology.terminology_not_exists'))

    words, other_words = _collect_words(info, trans)

    # Look for the same words on rows that are neither this parent nor one of
    # its children; the explicit NULL branch is needed because
    # ``pid != info.id`` is not true for rows whose pid is NULL.
    exists = session.query(
        session.query(Terminology).filter(
            Terminology.word.in_(words),
            Terminology.oid == oid,
            or_(
                Terminology.pid != info.id,
                and_(Terminology.pid.is_(None), Terminology.id != info.id)
            ),
            Terminology.id != info.id
        ).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    session.execute(
        update(Terminology).where(and_(Terminology.id == info.id)).values(
            word=info.word,
            description=info.description,
        )
    )
    session.execute(delete(Terminology).where(and_(Terminology.pid == info.id)))

    create_time = datetime.datetime.now()
    session.bulk_save_objects([
        Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid)
        for other_word in other_words
    ])
    # Single commit: the update, the removal of the old children and the
    # insertion of the new ones succeed or fail together.
    session.commit()

    # embedding
    run_save_embeddings([info.id])

    return info.id


def delete_terminology(session: SessionDep, ids: list[int]):
    """Delete the rows whose id is in ``ids`` together with their child rows."""
    if not ids:
        return
    stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
    session.execute(stmt)
    session.commit()


def run_save_embeddings(ids: List[int]):
    """Schedule ``save_embeddings(ids)`` on the module thread pool."""
    executor.submit(save_embeddings, ids)


def fill_empty_embeddings():
    """Schedule ``run_fill_empty_embeddings`` on the module thread pool."""
    executor.submit(run_fill_empty_embeddings)


@contextmanager
def _background_session():
    """Yield a session bound to a private engine, releasing both on exit."""
    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    session = sessionmaker(bind=engine)()
    try:
        yield session
    finally:
        session.close()
        engine.dispose()


def run_fill_empty_embeddings():
    """Compute embeddings for every term group that has a row without one.

    Does nothing when ``settings.EMBEDDING_ENABLED`` is false.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    # Parents without an embedding, plus the parents of children without one.
    stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
    stmt2 = select(Terminology.pid).where(
        and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
    combined_stmt = union(stmt1, stmt2)

    with _background_session() as session:
        results = session.execute(combined_stmt).scalars().all()

    save_embeddings(results)


def save_embeddings(ids: List[int]):
    """Compute and store the embedding of each row whose id or pid is in ``ids``.

    Does nothing when ``settings.EMBEDDING_ENABLED`` is false or ``ids`` is
    empty. Failures are printed and swallowed.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    if not ids:
        return

    try:
        with _background_session() as session:
            rows = session.query(Terminology).filter(
                or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
            if not rows:
                return

            model = EmbeddingModelCache.get_model()
            vectors = model.embed_documents([row.word for row in rows])

            for row, vector in zip(rows, vectors):
                session.execute(
                    update(Terminology).where(and_(Terminology.id == row.id)).values(embedding=vector))
            # One commit for the whole batch; committing inside the loop
            # would expire the loaded rows and reload one per iteration.
            session.commit()

    except Exception:
        traceback.print_exc()


# Similarity search over the stored embeddings. The threshold and the row
# limit come from settings; ``:embedding_array`` and ``:oid`` are bound at
# execution time. NOTE(review): ``<=>`` is presumably pgvector's
# cosine-distance operator - confirm against the column type.
embedding_sql = f"""
SELECT id, pid, word, similarity
FROM
(SELECT id, pid, word, oid,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TOP_COUNT}
"""


def _select_by_embedding(session: SessionDep, word: str, oid: int) -> list[tuple]:
    """Return ``(id, pid)`` of the rows matched by ``embedding_sql``.

    Best effort: any failure is printed and yields an empty list.
    """
    try:
        embedding = EmbeddingModelCache.get_model().embed_query(word)
    except Exception:
        traceback.print_exc()
        return []

    try:
        rows = session.execute(
            text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid}).fetchall()
    except Exception:
        traceback.print_exc()
        # A failed statement leaves the transaction unusable; roll back so
        # the caller can keep querying with this session.
        session.rollback()
        return []

    return [(row.id, row.pid) for row in rows]


def select_terminology_by_word(session: SessionDep, word: str, oid: int):
    """Find the term groups relevant to the sentence ``word``.

    A row matches when its word is contained in the sentence
    (case-insensitive) and, when embeddings are enabled, when it is returned
    by the similarity query. Each match is resolved to its parent.

    Returns:
        A list of ``{'words': [...], 'description': ...}`` dicts, one per
        parent, where ``words`` holds the words of the parent and of all its
        children. Empty when ``word`` is blank or nothing matches.
    """
    if not word or word.strip() == "":
        return []

    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
        )
        .where(
            and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid)
        )
    )
    matches = [(row.id, row.pid) for row in session.execute(stmt, {'sentence': word}).fetchall()]

    if settings.EMBEDDING_ENABLED:
        matches.extend(_select_by_embedding(session, word, oid))

    # Resolve every match to the id of its parent (a parent is its own root).
    root_ids = sorted({pid if pid is not None else term_id for term_id, pid in matches})
    if not root_ids:
        return []

    rows = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
        or_(Terminology.id.in_(root_ids), Terminology.pid.in_(root_ids))).all()

    groups: dict = {}
    for row in rows:
        root_id = row.pid if row.pid is not None else row.id
        group = groups.setdefault(root_id, {'words': [], 'description': row.description})
        group['words'].append(row.word)
        if row.pid is None:
            # The row order is unspecified, so a child may have created the
            # group; the description always comes from the parent row.
            group['description'] = row.description

    return list(groups.values())


def get_example():
    """Return a sample terminology document rendered by ``to_xml_string``."""
    _obj = {
        'terminologies': [
            {'words': ['GDP', '国内生产总值'],
             'description': '指在一个季度或一年，一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'},
        ]
    }
    return to_xml_string(_obj, 'example')


def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    """Render ``_dict`` as pretty-printed XML text under the element ``root``.

    The XML declaration is removed and the five predefined XML entities are
    replaced by their plain characters.
    """

    def item_name_func(parent: str) -> str:
        # Element name used for the items of the list stored under ``parent``.
        if parent == 'terminologies':
            return 'terminology'
        if parent == 'words':
            return 'word'
        return 'item'

    dicttoxml.LOG.setLevel(logging.ERROR)
    xml = dicttoxml.dicttoxml(_dict,
                              cdata=['word', 'description'],
                              custom_root=root,
                              item_func=item_name_func,
                              xml_declaration=False,
                              encoding='utf-8',
                              attr_type=False).decode('utf-8')
    pretty_xml = parseString(xml).toprettyxml()

    # Drop the "<?xml ... ?>" declaration emitted by toprettyxml().
    if pretty_xml.startswith('<?xml'):
        end_index = pretty_xml.find('>') + 1
        pretty_xml = pretty_xml[end_index:].lstrip()

    # Replace the XML escape sequences. '&amp;' goes last so that text such
    # as '&amp;quot;' is unescaped exactly once.
    escape_pairs = (
        ('&lt;', '<'),
        ('&gt;', '>'),
        ('&quot;', '"'),
        ('&apos;', "'"),
        ('&amp;', '&'),
    )
    for escaped, original in escape_pairs:
        pretty_xml = pretty_xml.replace(escaped, original)

    return pretty_xml


def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
    """Return the terminology template filled with the terms relevant to ``question``.

    A falsy ``oid`` is treated as 1. Returns an empty string when no term
    matches.
    """
    if not oid:
        oid = 1
    _results = select_terminology_by_word(session, question, oid)
    if not _results:
        return ''
    terminology = to_xml_string(_results)
    return get_base_terminology_template().format(terminologies=terminology)