431 lines
14 KiB
Python
431 lines
14 KiB
Python
import datetime
|
||
import logging
|
||
import traceback
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List, Optional
|
||
from xml.dom.minidom import parseString
|
||
|
||
import dicttoxml
|
||
from sqlalchemy import and_, or_, select, func, delete, update, union
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.orm import aliased
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
from apps.ai_model.embedding import EmbeddingModelCache
|
||
from apps.template.generate_chart.generator import get_base_terminology_template
|
||
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
|
||
from common.core.config import settings
|
||
from common.core.deps import SessionDep, Trans
|
||
|
||
# Shared thread pool for background embedding jobs: run_save_embeddings and
# fill_empty_embeddings submit work here and return without waiting for it.
_EXECUTOR_MAX_WORKERS = 200
executor = ThreadPoolExecutor(max_workers=_EXECUTOR_MAX_WORKERS)
|
||
|
||
|
||
def _terminology_page_counts(session: SessionDep, parent_ids_query, current_page: int, page_size: int):
    """Count the rows of ``parent_ids_query`` and clamp ``current_page``.

    Returns ``(current_page, total_count, total_pages)``; a page number past
    the last page is reset to 1.
    """
    total_count = session.execute(
        select(func.count()).select_from(parent_ids_query.subquery())
    ).scalar()
    total_pages = (total_count + page_size - 1) // page_size
    clamped_page = 1 if current_page > total_pages else current_page
    return clamped_page, total_count, total_pages


def _terminology_page_ids(parent_ids_query, current_page: int, page_size: int):
    """Restrict ``parent_ids_query`` to one page (newest ``create_time`` first) as a subquery."""
    return (
        parent_ids_query
        .order_by(Terminology.create_time.desc())
        .offset((current_page - 1) * page_size)
        .limit(page_size)
        .subquery()
    )


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
                     oid: Optional[int] = 1):
    """Return one page of top-level terminologies of an organization.

    Each item is a parent ``Terminology`` row (``pid IS NULL``) together with
    the words of its child rows, collected into ``other_words``.

    When ``name`` is non-blank, a parent is listed if its own word or the word
    of any of its children matches ``name`` case-insensitively as a substring
    (``ILIKE '%name%'``).

    Args:
        session: database session.
        current_page: 1-based page number; values below 1 become 1, and a page
            past the last one falls back to 1.
        page_size: rows per page; values below 10 become 10.
        name: optional search keyword.
        oid: organization id used to scope the rows.

    Returns:
        ``(current_page, page_size, total_count, total_pages, items)`` where
        ``items`` is a list of ``TerminologyInfo`` ordered by ``create_time``
        descending.
    """
    child = aliased(Terminology)

    current_page = max(1, current_page)
    page_size = max(10, page_size)

    keyword = name.strip() if name else ""
    if keyword:
        keyword_pattern = f"%{keyword}%"

        # Step 1: ids of every row (parent or child) whose word matches.
        matched_ids = (
            select(Terminology.id)
            .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))
            .subquery()
        )
        # Step 2: parents that matched themselves or own a matching child.
        parents_of_matched = (
            select(Terminology.pid)
            .where(Terminology.id.in_(matched_ids))
            .where(Terminology.pid.isnot(None))
        )
        parent_ids_query = (
            select(Terminology.id)
            .where(Terminology.id.in_(matched_ids) | Terminology.id.in_(parents_of_matched))
            .where(Terminology.pid.is_(None))  # parents only
        )

        current_page, total_count, total_pages = _terminology_page_counts(
            session, parent_ids_query, current_page, page_size)
        # Step 3: the parent ids on the requested page.
        page_ids = _terminology_page_ids(parent_ids_query, current_page, page_size)

        # Step 4: each parent's child words, aggregated per pid.
        synonyms = (
            select(
                child.pid,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words'),
            )
            .where(child.pid.isnot(None))
            .group_by(child.pid)
            .subquery()
        )

        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                synonyms.c.other_words,
            )
            .outerjoin(synonyms, Terminology.id == synonyms.c.pid)
            .where(and_(Terminology.id.in_(page_ids), Terminology.oid == oid))
            .order_by(Terminology.create_time.desc())
        )
    else:
        parent_ids_query = (
            select(Terminology.id)
            .where(and_(Terminology.pid.is_(None), Terminology.oid == oid))  # parents only
        )

        current_page, total_count, total_pages = _terminology_page_counts(
            session, parent_ids_query, current_page, page_size)
        page_ids = _terminology_page_ids(parent_ids_query, current_page, page_size)

        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words'),
            )
            .outerjoin(child, Terminology.id == child.pid)
            .where(and_(Terminology.id.in_(page_ids), Terminology.oid == oid))
            .group_by(Terminology.id, Terminology.word)
            .order_by(Terminology.create_time.desc())
        )

    items: List[TerminologyInfo] = [
        TerminologyInfo(
            id=row.id,
            word=row.word,
            create_time=row.create_time,
            description=row.description,
            other_words=row.other_words or [],
        )
        for row in session.execute(stmt)
    ]

    return current_page, page_size, total_count, total_pages, items
|
||
|
||
|
||
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Create a terminology (parent row) and its synonyms (child rows) atomically.

    Blank entries in ``info.other_words`` are ignored: they are neither stored
    nor considered by the duplicate checks.

    Args:
        session: database session; committed once on success, rolled back on
            failure during the write.
        info: incoming terminology; ``info.word`` becomes the parent row and
            each non-blank ``info.other_words`` entry becomes a child row.
        oid: organization id stamped on every inserted row.
        trans: translator used to build the error messages.

    Returns:
        The id of the new parent row.

    Raises:
        Exception: if the submitted words repeat each other, or if any of
            them already exists in the organization.
    """
    create_time = datetime.datetime.now()
    # Blank synonyms are never stored, so they must not trigger the checks either.
    other_words = [w for w in (info.other_words or []) if w.strip() != ""]

    words = [info.word]
    for other_word in other_words:
        if other_word in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        words.append(other_word)

    exists = session.query(
        session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)
    try:
        session.add(parent)
        # Flush (not commit) to obtain the generated id, so parent and children
        # are written in one transaction and a failure leaves nothing behind.
        session.flush()
        parent_id = parent.id

        children = [
            Terminology(pid=parent_id, word=other_word, create_time=create_time, oid=oid)
            for other_word in other_words
        ]
        if children:
            session.bulk_save_objects(children)
        session.commit()
    except Exception:
        session.rollback()
        raise

    # Embeddings are computed in the background once the rows are committed.
    run_save_embeddings([parent_id])

    return parent_id
|
||
|
||
|
||
def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Update a terminology and replace its synonyms atomically.

    The parent row's ``word`` and ``description`` are overwritten, all existing
    child rows are deleted, and one child row is inserted per non-blank entry
    of ``info.other_words``. Everything happens in a single transaction.

    Args:
        session: database session; committed once on success, rolled back on
            failure during the write.
        info: new values; ``info.id`` identifies the parent row.
        oid: organization id the terminology must belong to.
        trans: translator used to build the error messages.

    Returns:
        ``info.id``.

    Raises:
        Exception: if the terminology does not exist in the organization, if
            the submitted words repeat each other, or if any of them is used
            by another terminology of the organization.
    """
    count = session.query(Terminology).filter(
        Terminology.oid == oid,
        Terminology.id == info.id
    ).count()
    if count == 0:
        raise Exception(trans('i18n_terminology.terminology_not_exists'))

    # Blank synonyms are never stored, so they must not trigger the checks either.
    other_words = [w for w in (info.other_words or []) if w.strip() != ""]

    words = [info.word]
    for other_word in other_words:
        if other_word in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        words.append(other_word)

    # A conflict is any row of the organization other than this terminology
    # itself and its own children (which are about to be replaced).
    exists = session.query(
        session.query(Terminology).filter(
            Terminology.word.in_(words),
            Terminology.oid == oid,
            or_(Terminology.pid.is_(None), Terminology.pid != info.id),
            Terminology.id != info.id
        ).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    create_time = datetime.datetime.now()
    try:
        session.execute(
            update(Terminology).where(Terminology.id == info.id).values(
                word=info.word,
                description=info.description,
            )
        )
        # Replace synonyms wholesale: drop the old children, insert the new ones.
        session.execute(delete(Terminology).where(Terminology.pid == info.id))

        children = [
            Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid)
            for other_word in other_words
        ]
        if children:
            session.bulk_save_objects(children)
        session.commit()
    except Exception:
        session.rollback()
        raise

    # Embeddings are recomputed in the background once the rows are committed.
    run_save_embeddings([info.id])

    return info.id
|
||
|
||
|
||
def delete_terminology(session: SessionDep, ids: list[int], oid: Optional[int] = None):
    """Delete terminologies together with their synonyms.

    Args:
        session: database session; committed after the delete.
        ids: ids of the rows to delete; rows whose ``pid`` is in ``ids``
            (the synonyms of those terminologies) are deleted too.
        oid: when given, only rows of this organization are deleted, so ids
            belonging to another organization are left untouched. ``None``
            (the default) keeps the previous, unscoped behavior.
    """
    if not ids:
        # Nothing to delete; skip the round-trip and the commit.
        return

    condition = or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))
    if oid is not None:
        condition = and_(condition, Terminology.oid == oid)

    session.execute(delete(Terminology).where(condition))
    session.commit()
|
||
|
||
|
||
def run_save_embeddings(ids: List[int]):
    """Queue ``save_embeddings`` for ``ids`` on the shared executor; does not wait for it."""
    executor.submit(save_embeddings, ids=ids)
|
||
|
||
|
||
def fill_empty_embeddings():
    """Queue ``run_fill_empty_embeddings`` on the shared executor; does not wait for it."""
    job = run_fill_empty_embeddings
    executor.submit(job)
|
||
|
||
|
||
def run_fill_empty_embeddings():
    """Recompute embeddings for every terminology that has a row missing one.

    Collects the ids of parent rows whose ``embedding`` is NULL and the
    ``pid`` of child rows whose ``embedding`` is NULL, then hands them to
    ``save_embeddings`` (which re-embeds each parent together with its
    children). Does nothing when ``settings.EMBEDDING_ENABLED`` is false.
    Errors while collecting ids are printed and swallowed.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    engine = None
    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        session_factory = sessionmaker(bind=engine)
        # The context manager closes the session (returning its connection).
        with session_factory() as session:
            parents_missing = select(Terminology.id).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
            parents_of_children_missing = select(Terminology.pid).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
            ids = session.execute(union(parents_missing, parents_of_children_missing)).scalars().all()
    except Exception:
        traceback.print_exc()
        return
    finally:
        # This engine is private to the call; release its connection pool.
        if engine is not None:
            engine.dispose()

    save_embeddings(ids)
|
||
|
||
|
||
def save_embeddings(ids: List[int]):
    """Compute and store embeddings for terminologies and their synonyms.

    For the given ids, loads the rows with those ids plus every row whose
    ``pid`` is one of them, embeds all their ``word`` values in one batch with
    the cached embedding model, and writes each vector to the row's
    ``embedding`` column. Does nothing when embeddings are disabled or ``ids``
    is empty. Errors are printed and swallowed; ``run_save_embeddings`` submits
    this function to the background executor.
    """
    if not settings.EMBEDDING_ENABLED or not ids:
        return

    engine = None
    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        session_factory = sessionmaker(bind=engine)
        # The context manager closes the session (rolling back anything uncommitted).
        with session_factory() as session:
            rows = session.query(Terminology).filter(
                or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
            if not rows:
                return

            model = EmbeddingModelCache.get_model()
            vectors = model.embed_documents([row.word for row in rows])

            # embed_documents returns one vector per input word, in order.
            for row, vector in zip(rows, vectors):
                session.execute(
                    update(Terminology).where(Terminology.id == row.id).values(embedding=vector))
            session.commit()
    except Exception:
        traceback.print_exc()
    finally:
        # This engine is private to the call; release its connection pool.
        if engine is not None:
            engine.dispose()
|
||
|
||
|
||
embedding_sql = f"""
|
||
SELECT id, pid, word, similarity
|
||
FROM
|
||
(SELECT id, pid, word, oid,
|
||
( 1 - (embedding <=> :embedding_array) ) AS similarity
|
||
FROM terminology AS child
|
||
) TEMP
|
||
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
|
||
ORDER BY similarity DESC
|
||
LIMIT {settings.EMBEDDING_TOP_COUNT}
|
||
"""
|
||
|
||
|
||
def select_terminology_by_word(session: SessionDep, word: str, oid: int):
    """Find the terminologies relevant to a piece of text.

    Candidate rows come from two sources:
      1. literal matches: rows (parents or synonyms) of organization ``oid``
         whose ``word`` occurs in ``word`` as a case-insensitive substring;
      2. when ``settings.EMBEDDING_ENABLED``, rows returned by ``embedding_sql``
         for the embedding of ``word``.
    Every candidate is mapped to its parent terminology, and each parent is
    returned with all of its words (parent word first) and its description.

    Args:
        session: database session.
        word: the text to look up (e.g. a user question).
        oid: organization id.

    Returns:
        A list of ``{'words': [...], 'description': ...}`` dicts, one per
        terminology; empty when ``word`` is blank or nothing matches.
    """
    if word.strip() == "":
        return []

    literal_stmt = (
        select(Terminology.id, Terminology.pid, Terminology.word)
        .where(and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid))
    )
    candidates = list(session.execute(literal_stmt, {'sentence': word}).fetchall())

    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()
            embedding = model.embed_query(word)
            # Run the vector query inside a SAVEPOINT. On PostgreSQL a failed
            # statement aborts the whole transaction, so without it a swallowed
            # error here would make the follow-up query below fail as well.
            with session.begin_nested():
                similar = session.execute(
                    text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid}).fetchall()
            candidates.extend(similar)
        except Exception:
            traceback.print_exc()

    # A child row stands for its parent; a parent row stands for itself.
    parent_ids = {row.pid if row.pid is not None else row.id for row in candidates}
    if not parent_ids:
        return []

    ids = sorted(parent_ids)
    rows = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
        or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()

    grouped: dict[int, dict] = {}
    for row in rows:
        key = row.pid if row.pid is not None else row.id
        entry = grouped.setdefault(key, {'words': [], 'description': None})
        if row.pid is None:
            # Child rows carry no description; take it from the parent row,
            # whatever order the database returned the rows in.
            entry['description'] = row.description
            entry['words'].insert(0, row.word)
        else:
            entry['words'].append(row.word)

    return list(grouped.values())
|
||
|
||
|
||
def get_example():
    """Return a sample terminology rendered by ``to_xml_string`` under an ``<example>`` root."""
    sample = {
        'words': ['GDP', '国内生产总值'],
        'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。',
    }
    payload = {'terminologies': [sample]}
    return to_xml_string(payload, 'example')
|
||
|
||
|
||
def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    """Render terminology data as pretty-printed XML text.

    ``dicttoxml`` names list items after their parent key: items of
    ``terminologies`` become ``<terminology>``, items of ``words`` become
    ``<word>``, and anything else becomes ``<item>``. ``word`` and
    ``description`` values are wrapped in CDATA. The XML declaration added by
    ``toprettyxml`` is stripped and XML entities are unescaped, because the
    text is substituted into a string template (see get_terminology_template)
    rather than parsed again.

    Args:
        _dict: a dict or a list of dicts to render.
        root: name of the root element.

    Returns:
        The XML text without an XML declaration.
    """
    from xml.sax.saxutils import unescape

    def item_name_func(parent: str) -> str:
        if parent == 'terminologies':
            return 'terminology'
        if parent == 'words':
            return 'word'
        return 'item'

    dicttoxml.LOG.setLevel(logging.ERROR)
    xml = dicttoxml.dicttoxml(_dict,
                              cdata=['word', 'description'],
                              custom_root=root,
                              item_func=item_name_func,
                              xml_declaration=False,
                              encoding='utf-8',
                              attr_type=False).decode('utf-8')
    pretty_xml = parseString(xml).toprettyxml()

    if pretty_xml.startswith('<?xml'):
        end_index = pretty_xml.find('>') + 1
        pretty_xml = pretty_xml[end_index:].lstrip()

    # Undo XML escaping. saxutils.unescape handles &lt;/&gt;, then the extra
    # entities below, and &amp; last, so "&amp;lt;" correctly yields "&lt;"
    # instead of being unescaped twice into "<".
    return unescape(pretty_xml, {'&quot;': '"', '&apos;': "'", '&#39;': "'"})
|
||
|
||
|
||
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
    """Build the terminology text block for ``question``.

    Looks up terminologies matching ``question`` in organization ``oid``
    (a falsy ``oid`` falls back to 1). If any are found, renders them with
    ``to_xml_string`` into ``get_base_terminology_template()``; otherwise
    returns an empty string.
    """
    matches = select_terminology_by_word(session, question, oid or 1)
    if not matches:
        return ''
    rendered = to_xml_string(matches)
    return get_base_terminology_template().format(terminologies=rendered)
|