diff --git a/main/da_chat.py b/main/da_chat.py new file mode 100644 index 0000000..f0e0c60 --- /dev/null +++ b/main/da_chat.py @@ -0,0 +1,625 @@ +import os +import json5 +import json +from typing import Optional + +from qwen_agent.agents import Assistant, Router +from typing import Dict,List,Union,Iterator,Optional +from qwen_agent.llm.schema import Message,ROLE,USER,ASSISTANT +from qwen_agent.tools.base import BaseTool, register_tool +import time + +try: + import polars as pl +except: + pass +from init import llm,gcfg + +class Nl2SQLAgent(Assistant): + """This is an agent for nl to sql + + 根据自然语言的描述生成SQL语句,并执行返回结果。 + 使用system作为模板 + """ + def __init__(self, + function_list: Optional[List[Union[str, Dict, BaseTool]]] = [], + llm= None, + system_message: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + files: Optional[List[str]] = None, + rag_cfg: Optional[Dict] = None): + self.default_sql_tool_name="do_sql" + if self.default_sql_tool_name not in function_list: + function_list=[self.default_sql_tool_name] + (function_list) + super().__init__(function_list=function_list, + llm=llm, + system_message=system_message, + name=name, + description=description, + files=files, + rag_cfg=rag_cfg) + + +# Add a custom tool named do_py_eval +@register_tool('do_py_eval') +class DO_PY_EVAL(BaseTool): + description = '数值计算工具,如: 资产负债比率=241986307.9 / 666947571.54' + parameters = [{ + 'name': 'name', + 'type': 'string', + 'description': '指标名称,如资产负债比率', + 'required': True, + }, + { + 'name': 'expr', + 'type': 'string', + 'description': '数值计算的内容,如241986307.9 / 666947571.54', + 'required': True, + } + ] + + def call(self, params: str, **kwargs) -> str: + name = json5.loads(params)['name'] + expr = json5.loads(params)['expr'] + expr = expr.strip() + if expr.find(":") >0: + expr = expr[0:expr.find(":")] + elif expr.startswith("do_py_eval"): + print("do_py_eval:",expr) + d = eval(expr) + else: + print("do_eval:",expr) + import sympy + try: + result = sympy.sympify(expr) + d={name:float(round(result,4))} + except Exception as e: + d={name:f"计算出现错误{e}","数值计算式":expr} + return json.dumps(d,ensure_ascii=False) + + +# Add a custom tool named do_py_eval +@register_tool('do_desc_excel') +class DO_desc_excel(BaseTool): + description = '简要分析这个excel的数据,给出简要的数据信息,包括行数,列数,列名和列类型以及字符串列的典型的值' + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + } + ] + + def call(self, params: str, **kwargs) -> str: + file_path = json5.loads(params)['file_path'] + df = pl.read_excel(file_path) + + # 获取 DataFrame 中所有字符串列的名称 + str_columns = [col for col in df.columns if df[col].dtype == pl.Utf8] + + # 对每个字符串列进行去重,并打印结果 + unique_values_per_column = {col: df.select(pl.col(col).unique()).to_series().to_list() for col in str_columns} + + dict_string={} + for column, unique_values in unique_values_per_column.items(): + #print(f"Column '{column}' unique values: {unique_values}") + if len(unique_values) <20: + dict_string[column]=unique_values + + print(df.width,df.height,df.columns,df.dtypes) + + d={"列数":df.width,"行数":df.height,"列名":df.columns,"列类型":[str(t) for t in df.dtypes],"字符串列的典型值":dict_string} + + return json.dumps(d,ensure_ascii=False) + + +# Add a custom tool named do_py_eval +@register_tool('do_gen_excel') +class DO_gen_excel(BaseTool): + description = '使用json数据,生成一个新的excel文件' + parameters = [ + { + 'name': 'json_data', + 'type': 'string', + 'description': 'json数据', + 'required': True, + } + ] + + def call(self, params: str, **kwargs) -> str: + json_data = json5.loads(params)['json_data'] + # 解析 JSON 字符串为 Python 对象 + data = json5.loads(json_data) + + # 使用 Polars 创建 DataFrame + df = pl.DataFrame(data) + + file_path =f'{gcfg["fs"]["path"]}/pub/{time.time()}.xlsx' + + df.write_excel(file_path) + + print(df.width,df.height,df.columns,df.dtypes) + d={"文件":file_path} + + return json.dumps(d,ensure_ascii=False) + + + +#根据数据生成excel文件 +def gen_data_to_excel(path,data,desc): + #path为源文件的路径, + from pathlib import Path + src_path = Path(path) + #不带后缀的部分 + base = src_path.stem + # 使用 Polars 创建 DataFrame + df = pl.DataFrame(data) + + file_path =f'{gcfg["fs"]["path"]}/pub/{base}_{desc}_{int(time.time())}.xlsx' + + df.write_excel(file_path) + return file_path + + +# Add a custom tool named do_py_eval +@register_tool('do_group_count') +class DO_group_count(BaseTool): + description = '对列的分组统计和记录数求和,支持多个分组字段,适用于做简单的数据数量分布' + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + }, + { + 'name': 'group_fields', + 'type': 'string', + 'description': '分组字段,多个字段用,隔开,支持tags列', + 'required': True, + }, + { + 'name': 'gen_excel', + 'type': 'bool', + 'description': '是否将结果,生成新的excel文件,默认为False', + 'required': False, + } + ] + + def call(self, params: str, **kwargs) -> str: + params = json5.loads(params) + file_path = params['file_path'] + group_fields = params['group_fields'] + gen_excel = False + if "gen_excel" in params: + gen_excel = params["gen_excel"] + + group_fields = group_fields.split(",") + dataset = pl.read_excel(file_path) + + # 动态处理所有字符串列为空的情况 + #string_columns = [column for column, dtype in df.dtypes.items() if dtype == pl.Utf8] + + # 动态处理group 字段的空值问题 + # dataset = dataset.with_columns( + # [ + # pl.col(col).apply(lambda x: None if x == "" else x).alias(col) + # for col in group_fields + # ] + dataset = dataset.with_columns( + [ + pl.col(col).fill_null("未知") for col in group_fields + ] + ) + + if 'tags' in group_fields: + if dataset['tags'][0].startswith("[") and dataset['tags'][0].endswith("]"): + # 将字符串还原为list对象 + dataset= dataset.with_columns([ + # pl.col(group_fields[0]) + # .map_elements(lambda x:json.loads(x), return_dtype=pl.List(pl.Utf8)) # 先还原为 list + + pl.col('tags').str.json_decode() + + ]) + + if 'tags' in group_fields and dataset['tags'].dtype==pl.List: + #print("多标签分组") + #list 类型的字段,进行展开统计 + df_exploded = dataset.explode('tags') + #print(df_exploded.width,df_exploded.height,df_exploded.columns,df_exploded.dtypes) + + # 对展开后的DataFrame按"tags"列进行分组,并统计每个标签出现的次数 + q = (df_exploded.lazy().group_by(group_fields) + .agg( + pl.count().alias("count") + ) + .limit(100) + .sort("count", descending=True) + ) + else: + # group_by + q = ( + dataset.lazy() + .group_by(group_fields) + .agg( + pl.len().alias('count'), + ) + .limit(100) + .sort("count", descending=True) + ) + df = q.collect() + + if gen_excel: + new_file_path = gen_data_to_excel(file_path,df.to_dicts(),"分组数量") + return {"excel":new_file_path,"data":df.write_json(),"cols":df.columns} + else: + return {"data":df.write_json(),"cols":df.columns} + + +# Add a custom tool named do_py_eval +@register_tool('do_group_agg') +class DO_group_agg(BaseTool): + description = '对一个列或多个列的分组后,对另一列求累计汇总、平均值、数量、中位数、最小值、最大值6个指标的统计分析' + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + }, + { + 'name': 'group_fields', + 'type': 'string', + 'description': '分组字段,不支持tags列', + 'required': True, + }, + { + 'name': 'agg_field', + 'type': 'string', + 'description': '聚合字段', + 'required': True, + }, + { + 'name': 'gen_excel', + 'type': 'bool', + 'description': '是否将结果,生成新的excel文件,默认为False', + 'required': False, + } + ] + + def call(self, params: str, **kwargs) -> str: + params = json5.loads(params) + file_path = params['file_path'] + group_fields = params['group_fields'] + group_fields = group_fields.split(",") + + sum_field = params['agg_field'] + gen_excel = False + if "gen_excel" in params: + gen_excel = params["gen_excel"] + + + dataset = pl.read_excel(file_path) + + dataset = dataset.with_columns( + [ + pl.col(col).fill_null("未知") for col in group_fields + ] + ) + + # group_by + q = ( + dataset.lazy() + .group_by(group_fields) + .agg( + pl.col(sum_field).sum().alias("total"), + pl.col(sum_field).mean().alias("avg"), + pl.col(sum_field).count().alias("count"), + pl.col(sum_field).median().alias("median"), + pl.col(sum_field).min().alias("min"), + pl.col(sum_field).max().alias("max"), + ) + .limit(100) + .sort("count", descending=True) + ) + + df = q.collect() + + if gen_excel: + new_file_path = gen_data_to_excel(file_path,df.to_dicts(),"分组指标") + return {"excel":new_file_path,"data":df.write_json(),"cols":df.columns,"field":sum_field} + else: + return {"data":df.write_json(),"cols":df.columns,"field":sum_field} + + +# Add a custom tool named do_py_eval +@register_tool('do_fields_agg') +class DO_fields_agg(BaseTool): + description = """使用全部数据,对某些列求累计汇总、平均值、数量、中位数、最小值、最大值6个指标的统计分析, + 不支持tags列 + """ + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + }, + { + 'name': 'agg_fields', + 'type': 'string', + 'description': '聚合字段', + 'required': True, + }, + { + 'name': 'gen_excel', + 'type': 'bool', + 'description': '是否将结果,生成新的excel文件,默认为False', + 'required': False, + } + ] + + def call(self, params: str, **kwargs) -> str: + params = json5.loads(params) + file_path = params['file_path'] + agg_fields = params['agg_fields'] + agg_fields = agg_fields.split(",") + gen_excel = False + if "gen_excel" in params: + gen_excel = params["gen_excel"] + + dataset = pl.read_excel(file_path) + + dfs=[] + for agg_field in agg_fields: + # agg + df = ( + dataset + .select( + pl.col(agg_field).sum().alias("total"), + pl.col(agg_field).mean().alias("avg"), + pl.col(agg_field).count().alias("count"), + pl.col(agg_field).median().alias("median"), + pl.col(agg_field).min().alias("min"), + pl.col(agg_field).max().alias("max"), + ).with_columns( + pl.lit(agg_field).alias("类别") + ) + ) + #调整顺序 + df = df.select("类别","total","avg","count","median","min","max") + dfs.append(df) + + if len(dfs) >=2: + df = pl.concat(dfs, how="vertical") + + if gen_excel: + new_file_path = gen_data_to_excel(file_path,df.to_dicts(),"聚合指标") + return {"excel":new_file_path,"data":df.write_json(),"cols":df.columns,"field":agg_fields} + else: + return {"data":df.write_json(),"cols":df.columns,"field":agg_fields} + + + +# Add a custom tool named do_py_eval +@register_tool('do_filter') +class DO_filter(BaseTool): + description = """根据条件过滤数据,并生成新的数据文件excel供后续分析使用 + + 数据筛选时,使用python polars的filter函数支持的筛选条件进行,如: + 1. 年龄大于20的,filter_expr为 pl.col("年龄") > 20 + 2. 姓张的人, filter_expr为 pl.col("name").str.contains("张") + 3. L1,L2,l3级别的, filter_expr为 pl.col("级别").is_in(["L1","L2","L3"]) + 4. 可以使用 .is_null() 或 .is_not_null() 方法来筛选出包含空值或非空值的行。 + 5. 利用逻辑运算符 &(与)、|(或)、~(非)来组合多个条件,每个条件都要单独使用()包裹,如下所示: + (pl.col("a") > 1) & ~(pl.col("b") == 5) + """ + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + }, + { + 'name': 'filter_expr', + 'type': 'string', + 'description': '数据筛选表达式,如果不用筛选数据,内容为""', + 'required': True, + }, + ] + + def call(self, params: str, **kwargs) -> str: + params = json5.loads(params) + file_path = params['file_path'] + filter_expr = params['filter_expr'] + + + dataset = pl.read_excel(file_path) + + #过滤数据 + if filter_expr !="": + df = dataset.filter(eval(filter_expr)) + else: + df = dataset + + new_file_path = gen_data_to_excel(file_path,df.to_dicts(),"过滤") + return {"excel":new_file_path,"data":df.limit(20).write_json(),"cols":df.columns,"desc":f"数据过滤:{filter_expr}"} + + + +# 定义一个函数来提取文本中的关键字并返回作为标签 +def extract_tags(text, keywords): + # 返回文本中存在的关键字列表 + tags = [keyword for keyword in keywords if keyword in text] + return tags if tags else [text] # 如果没有匹配到任何关键字,则返回最开始的值 + + +# Add a custom tool named do_py_eval +@register_tool('do_ext_tags') +class DO_ext_tags(BaseTool): + description = '根据定义好的关键字列表,对某一文本字段进行标签提取操作,增加一个新的list类型的tags列' + parameters = [{ + 'name': 'file_path', + 'type': 'string', + 'description': '文件的路径', + 'required': True, + }, + { + 'name': 'text_field', + 'type': 'string', + 'description': '文本字段', + 'required': True, + }, + { + 'name': 'keywords', + 'type': 'string', + 'description': '关键字列表,多个关键字用,号分割', + 'required': True, + }, + ] + + def call(self, params: str, **kwargs) -> str: + params = json5.loads(params) + file_path = params['file_path'] + keywords = params['keywords'] + keywords = keywords.split(",") + + text_field = params['text_field'] + + + dataset = pl.read_excel(file_path) + + df = dataset.with_columns( + pl.col(text_field) + .map_elements(lambda x: extract_tags(x, keywords),return_dtype=pl.List(pl.Utf8)) + .alias("tags") + ) + + #print(df) + #print(df['tags'].dtype) + + #转换成标准的json字符串的形式,以便写入 + df = df.with_columns( + pl.col("tags").map_elements(lambda x: json.dumps(list(x),ensure_ascii=False),return_dtype=pl.String) + ) + new_file_path = gen_data_to_excel(file_path,df.to_dicts(),"标签提取") + return {"excel":new_file_path,"data":df.limit(20).write_json(),"cols":df.columns,"desc":f"标签提取:{text_field}字段提取{keywords}"} + + +def init_agent_service(): + + system = ('你是一个数据分析助手,你可以调用各种工具分析数据,所有的工具参数都不要臆想,而是使用对话中的信息生成。' + '你可以先调用do_desc_excel工具来知道原始文件数据的行数,列数和列名等信息,避免出错。' + '分析完成后,并对结果数据做些分析点评。' + '调用do_ext_tags会新增tags标签字段并生成新的文件,后续的分析要使用新的这个文件,而不是原始文件。' + '调用do_filter会过滤数据并生成新的文件,基于过滤数据的分析要使用新的这个文件,而不是原始文件。' + '除非明确要求生成文件,否则都不生成文件。' + '如果所有工具都不能满足时,就返回"对不起,暂时无法回答你的问题"') + + tools = ['do_desc_excel','do_group_count','do_group_agg','do_fields_agg','do_ext_tags',"do_filter"] + bot = Assistant( + llm=llm, + name='数据分析助手', + description='分析Excel的数据', + system_message=system, + function_list=tools, + ) + + return bot + +#用于智能体对话的入口函数 +def Agent_DA_chat(query,file_path): + # Define the agent + bot = init_agent_service() + + messages = [] + messages.append({'role': 'user', 'content': f'原始文件:{file_path} {query} '}) + + response={} + i=1 + try: + for response in bot.run(messages): + if response[-1]["content"]!="": + if response[-1]["role"]=="function": + yield f'{response[-1]["name"]}{response[-1]["content"]}' + else: + yield response[-1]["content"] + except Exception as e: + yield f"大模型运行出现错误:{e}, 请检查配置是否正确或者网络通信是否通畅!\n如果是本地大模型还请检查大模型是否正常启动!" + return "" + + return response[-1]["content"] + + # for response in bot.run_nonstream(messages): + # print("-----------------------") + # print(response) + # json_str = "data: " + json.dumps({'rsp':response["content"]}) + "\n\n" + # yield json_str.encode("utf-8") + + +def DA_chat(query,ctx,file_path): + # Define the agent + bot = init_agent_service() + + messages = [] + messages.append({'role': 'user', 'content': f'原始文件:{file_path} {query} '}) + + response={} + i=1 + try: + for response in bot.run(messages): + if len(response)>i: + json_str = "data: " + json.dumps({'rsp':""}) + "\n\n" + i+=1 + yield json_str.encode("utf-8") + if response[-1]["content"]!="": + if response[-1]["role"]=="function": + json_str = "data: " + json.dumps({'rsp':response[-1]["content"],'type':response[-1]["name"]}) + "\n\n" + else: + json_str = "data: " + json.dumps({'rsp':response[-1]["content"]}) + "\n\n" + yield json_str.encode("utf-8") + except Exception as e: + json_str = "data: " + json.dumps({"rsp": f"大模型运行出现错误:{e}, 请检查配置是否正确或者网络通信是否通畅!\n如果是本地大模型还请检查大模型是否正常启动!"}) + "\n\n" + yield json_str.encode("utf-8") + return json_str + + return response[-1]["content"] + + + +def main(file="/home/elementary/下载/20250506.xlsx"): + + questions=[ + #"简要分析这个文件中有什么信息", + "先分析区域分布,然后将分析结果写入一个新的excel", + #"客户类型分布", + #"需求分布", + #'分析不同区域不同客户类型的数量' + ] + + # Define the agent + bot = init_agent_service() + + for query in questions: + messages = [] + messages.append({'role': 'user', 'content': f'源文件:{file} {query} '}) + + response={} + for response in bot.run(messages,stream=False): + #print('bot response:', response) + pass + + + print(response[-1]["content"]) + return + + + +if __name__ == '__main__': + llm_cfg = { + # Use your own model service compatible with OpenAI API: + 'model': 'Qwen/Qwen1.5-72B-Chat', + 'model_server': 'http://192.168.124.78:8082/v1', # api_base + 'api_key': 'EMPTY', + 'generate_cfg': {"max_input_tokens":35000} + } + main() \ No newline at end of file