Add File
This commit is contained in:
430
backend/apps/terminology/curd/terminology.py
Normal file
430
backend/apps/terminology/curd/terminology.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import datetime
|
||||
import logging
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional
|
||||
from xml.dom.minidom import parseString
|
||||
|
||||
import dicttoxml
|
||||
from sqlalchemy import and_, or_, select, func, delete, update, union
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from apps.ai_model.embedding import EmbeddingModelCache
|
||||
from apps.template.generate_chart.generator import get_base_terminology_template
|
||||
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
|
||||
from common.core.config import settings
|
||||
from common.core.deps import SessionDep, Trans
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=200)
|
||||
|
||||
|
||||
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
|
||||
oid: Optional[int] = 1):
|
||||
_list: List[TerminologyInfo] = []
|
||||
|
||||
child = aliased(Terminology)
|
||||
|
||||
current_page = max(1, current_page)
|
||||
page_size = max(10, page_size)
|
||||
|
||||
total_count = 0
|
||||
total_pages = 0
|
||||
|
||||
if name and name.strip() != "":
|
||||
keyword_pattern = f"%{name.strip()}%"
|
||||
# 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点)
|
||||
matched_ids_subquery = (
|
||||
select(Terminology.id)
|
||||
.where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid)) # LIKE查询条件
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 步骤2:找到这些匹配节点的所有父节点(包括自身如果是父节点)
|
||||
parent_ids_subquery = (
|
||||
select(Terminology.id)
|
||||
.where(
|
||||
(Terminology.id.in_(matched_ids_subquery)) |
|
||||
(Terminology.id.in_(
|
||||
select(Terminology.pid)
|
||||
.where(Terminology.id.in_(matched_ids_subquery))
|
||||
.where(Terminology.pid.isnot(None))
|
||||
))
|
||||
)
|
||||
.where(Terminology.pid.is_(None)) # 只取父节点
|
||||
)
|
||||
|
||||
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
|
||||
total_count = session.execute(count_stmt).scalar()
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
if current_page > total_pages:
|
||||
current_page = 1
|
||||
|
||||
# 步骤3:获取分页后的父节点ID
|
||||
paginated_parent_ids = (
|
||||
parent_ids_subquery
|
||||
.order_by(Terminology.create_time.desc())
|
||||
.offset((current_page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 步骤4:获取这些父节点的childrenNames
|
||||
children_subquery = (
|
||||
select(
|
||||
child.pid,
|
||||
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
|
||||
)
|
||||
.where(child.pid.isnot(None))
|
||||
.group_by(child.pid)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 主查询
|
||||
stmt = (
|
||||
select(
|
||||
Terminology.id,
|
||||
Terminology.word,
|
||||
Terminology.create_time,
|
||||
Terminology.description,
|
||||
children_subquery.c.other_words
|
||||
)
|
||||
.outerjoin(
|
||||
children_subquery,
|
||||
Terminology.id == children_subquery.c.pid
|
||||
)
|
||||
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
|
||||
.order_by(Terminology.create_time.desc())
|
||||
)
|
||||
else:
|
||||
parent_ids_subquery = (
|
||||
select(Terminology.id)
|
||||
.where(and_(Terminology.pid.is_(None), Terminology.oid == oid)) # 只取父节点
|
||||
)
|
||||
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
|
||||
total_count = session.execute(count_stmt).scalar()
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
if current_page > total_pages:
|
||||
current_page = 1
|
||||
|
||||
paginated_parent_ids = (
|
||||
parent_ids_subquery
|
||||
.order_by(Terminology.create_time.desc())
|
||||
.offset((current_page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
Terminology.id,
|
||||
Terminology.word,
|
||||
Terminology.create_time,
|
||||
Terminology.description,
|
||||
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
|
||||
)
|
||||
.outerjoin(child, and_(Terminology.id == child.pid))
|
||||
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
|
||||
.group_by(Terminology.id, Terminology.word)
|
||||
.order_by(Terminology.create_time.desc())
|
||||
)
|
||||
|
||||
result = session.execute(stmt)
|
||||
|
||||
for row in result:
|
||||
_list.append(TerminologyInfo(
|
||||
id=row.id,
|
||||
word=row.word,
|
||||
create_time=row.create_time,
|
||||
description=row.description,
|
||||
other_words=row.other_words if row.other_words else [],
|
||||
))
|
||||
|
||||
return current_page, page_size, total_count, total_pages, _list
|
||||
|
||||
|
||||
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
|
||||
create_time = datetime.datetime.now()
|
||||
parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)
|
||||
|
||||
words = [info.word]
|
||||
for child in info.other_words:
|
||||
if child in words:
|
||||
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
|
||||
else:
|
||||
words.append(child)
|
||||
|
||||
exists = session.query(
|
||||
session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()).scalar()
|
||||
if exists:
|
||||
raise Exception(trans("i18n_terminology.exists_in_db"))
|
||||
|
||||
result = Terminology(**parent.model_dump())
|
||||
|
||||
session.add(parent)
|
||||
session.flush()
|
||||
session.refresh(parent)
|
||||
|
||||
result.id = parent.id
|
||||
session.commit()
|
||||
|
||||
_list: List[Terminology] = []
|
||||
if info.other_words:
|
||||
for other_word in info.other_words:
|
||||
if other_word.strip() == "":
|
||||
continue
|
||||
_list.append(
|
||||
Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid))
|
||||
session.bulk_save_objects(_list)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# embedding
|
||||
run_save_embeddings([result.id])
|
||||
|
||||
return result.id
|
||||
|
||||
|
||||
def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
|
||||
count = session.query(Terminology).filter(
|
||||
Terminology.oid == oid,
|
||||
Terminology.id == info.id
|
||||
).count()
|
||||
if count == 0:
|
||||
raise Exception(trans('i18n_terminology.terminology_not_exists'))
|
||||
|
||||
words = [info.word]
|
||||
for child in info.other_words:
|
||||
if child in words:
|
||||
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
|
||||
else:
|
||||
words.append(child)
|
||||
|
||||
exists = session.query(
|
||||
session.query(Terminology).filter(
|
||||
Terminology.word.in_(words),
|
||||
Terminology.oid == oid,
|
||||
or_(
|
||||
Terminology.pid != info.id,
|
||||
and_(Terminology.pid.is_(None), Terminology.id != info.id)
|
||||
),
|
||||
Terminology.id != info.id
|
||||
).exists()).scalar()
|
||||
if exists:
|
||||
raise Exception(trans("i18n_terminology.exists_in_db"))
|
||||
|
||||
stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
|
||||
word=info.word,
|
||||
description=info.description,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
stmt = delete(Terminology).where(and_(Terminology.pid == info.id))
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
create_time = datetime.datetime.now()
|
||||
_list: List[Terminology] = []
|
||||
if info.other_words:
|
||||
for other_word in info.other_words:
|
||||
if other_word.strip() == "":
|
||||
continue
|
||||
_list.append(
|
||||
Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid))
|
||||
session.bulk_save_objects(_list)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# embedding
|
||||
run_save_embeddings([info.id])
|
||||
|
||||
return info.id
|
||||
|
||||
|
||||
def delete_terminology(session: SessionDep, ids: list[int]):
|
||||
stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
|
||||
def run_save_embeddings(ids: List[int]):
|
||||
executor.submit(save_embeddings, ids)
|
||||
|
||||
|
||||
def fill_empty_embeddings():
|
||||
executor.submit(run_fill_empty_embeddings)
|
||||
|
||||
|
||||
def run_fill_empty_embeddings():
|
||||
if not settings.EMBEDDING_ENABLED:
|
||||
return
|
||||
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
|
||||
session_maker = sessionmaker(bind=engine)
|
||||
session = session_maker()
|
||||
stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
|
||||
stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
|
||||
combined_stmt = union(stmt1, stmt2)
|
||||
results = session.execute(combined_stmt).scalars().all()
|
||||
save_embeddings(results)
|
||||
|
||||
|
||||
def save_embeddings(ids: List[int]):
|
||||
if not settings.EMBEDDING_ENABLED:
|
||||
return
|
||||
|
||||
if not ids or len(ids) == 0:
|
||||
return
|
||||
try:
|
||||
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
|
||||
session_maker = sessionmaker(bind=engine)
|
||||
session = session_maker()
|
||||
|
||||
_list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
|
||||
|
||||
_words_list = [item.word for item in _list]
|
||||
|
||||
model = EmbeddingModelCache.get_model()
|
||||
|
||||
results = model.embed_documents(_words_list)
|
||||
|
||||
for index in range(len(results)):
|
||||
item = results[index]
|
||||
stmt = update(Terminology).where(and_(Terminology.id == _list[index].id)).values(embedding=item)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
embedding_sql = f"""
|
||||
SELECT id, pid, word, similarity
|
||||
FROM
|
||||
(SELECT id, pid, word, oid,
|
||||
( 1 - (embedding <=> :embedding_array) ) AS similarity
|
||||
FROM terminology AS child
|
||||
) TEMP
|
||||
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
|
||||
ORDER BY similarity DESC
|
||||
LIMIT {settings.EMBEDDING_TOP_COUNT}
|
||||
"""
|
||||
|
||||
|
||||
def select_terminology_by_word(session: SessionDep, word: str, oid: int):
|
||||
if word.strip() == "":
|
||||
return []
|
||||
|
||||
_list: List[Terminology] = []
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
Terminology.id,
|
||||
Terminology.pid,
|
||||
Terminology.word,
|
||||
)
|
||||
.where(
|
||||
and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid)
|
||||
)
|
||||
)
|
||||
|
||||
results = session.execute(stmt, {'sentence': word}).fetchall()
|
||||
|
||||
for row in results:
|
||||
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
|
||||
|
||||
if settings.EMBEDDING_ENABLED:
|
||||
try:
|
||||
model = EmbeddingModelCache.get_model()
|
||||
|
||||
embedding = model.embed_query(word)
|
||||
|
||||
results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})
|
||||
|
||||
for row in results:
|
||||
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
_map: dict = {}
|
||||
_ids: list[int] = []
|
||||
for row in _list:
|
||||
if row.id in _ids or row.pid in _ids:
|
||||
continue
|
||||
if row.pid is not None:
|
||||
_ids.append(row.pid)
|
||||
else:
|
||||
_ids.append(row.id)
|
||||
|
||||
if len(_ids) == 0:
|
||||
return []
|
||||
|
||||
t_list = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
|
||||
or_(Terminology.id.in_(_ids), Terminology.pid.in_(_ids))).all()
|
||||
for row in t_list:
|
||||
pid = str(row.pid) if row.pid is not None else str(row.id)
|
||||
if _map.get(pid) is None:
|
||||
_map[pid] = {'words': [], 'description': row.description}
|
||||
_map[pid]['words'].append(row.word)
|
||||
|
||||
_results: list[dict] = []
|
||||
for key in _map.keys():
|
||||
_results.append(_map.get(key))
|
||||
|
||||
return _results
|
||||
|
||||
|
||||
def get_example():
|
||||
_obj = {
|
||||
'terminologies': [
|
||||
{'words': ['GDP', '国内生产总值'],
|
||||
'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'},
|
||||
]
|
||||
}
|
||||
return to_xml_string(_obj, 'example')
|
||||
|
||||
|
||||
def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
|
||||
item_name_func = lambda x: 'terminology' if x == 'terminologies' else 'word' if x == 'words' else 'item'
|
||||
dicttoxml.LOG.setLevel(logging.ERROR)
|
||||
xml = dicttoxml.dicttoxml(_dict,
|
||||
cdata=['word', 'description'],
|
||||
custom_root=root,
|
||||
item_func=item_name_func,
|
||||
xml_declaration=False,
|
||||
encoding='utf-8',
|
||||
attr_type=False).decode('utf-8')
|
||||
pretty_xml = parseString(xml).toprettyxml()
|
||||
|
||||
if pretty_xml.startswith('<?xml'):
|
||||
end_index = pretty_xml.find('>') + 1
|
||||
pretty_xml = pretty_xml[end_index:].lstrip()
|
||||
|
||||
# 替换所有 XML 转义字符
|
||||
escape_map = {
|
||||
'<': '<',
|
||||
'>': '>',
|
||||
'&': '&',
|
||||
'"': '"',
|
||||
''': "'"
|
||||
}
|
||||
for escaped, original in escape_map.items():
|
||||
pretty_xml = pretty_xml.replace(escaped, original)
|
||||
|
||||
return pretty_xml
|
||||
|
||||
|
||||
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
|
||||
if not oid:
|
||||
oid = 1
|
||||
_results = select_terminology_by_word(session, question, oid)
|
||||
if _results and len(_results) > 0:
|
||||
terminology = to_xml_string(_results)
|
||||
template = get_base_terminology_template().format(terminologies=terminology)
|
||||
return template
|
||||
else:
|
||||
return ''
|
||||
Reference in New Issue
Block a user