Add File
This commit is contained in:
@@ -0,0 +1,307 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
|
||||||
|
* <p>
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
* <p>
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* <p>
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package com.agentsflex.core.chain.node;
|
||||||
|
|
||||||
|
import com.agentsflex.core.chain.Chain;
|
||||||
|
import com.agentsflex.core.chain.DataType;
|
||||||
|
import com.agentsflex.core.chain.Parameter;
|
||||||
|
import com.agentsflex.core.llm.ChatOptions;
|
||||||
|
import com.agentsflex.core.llm.Llm;
|
||||||
|
import com.agentsflex.core.llm.response.AiMessageResponse;
|
||||||
|
import com.agentsflex.core.message.AiMessage;
|
||||||
|
import com.agentsflex.core.message.SystemMessage;
|
||||||
|
import com.agentsflex.core.prompt.TextPrompt;
|
||||||
|
import com.agentsflex.core.prompt.template.TextPromptTemplate;
|
||||||
|
import com.agentsflex.core.util.CollectionUtil;
|
||||||
|
import com.agentsflex.core.util.Maps;
|
||||||
|
import com.agentsflex.core.util.StringUtil;
|
||||||
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.alibaba.fastjson.JSONArray;
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class LlmNode extends BaseNode {
|
||||||
|
|
||||||
|
protected Llm llm;
|
||||||
|
protected ChatOptions chatOptions = ChatOptions.DEFAULT;
|
||||||
|
protected String userPrompt;
|
||||||
|
protected TextPromptTemplate userPromptTemplate;
|
||||||
|
|
||||||
|
protected String systemPrompt;
|
||||||
|
protected TextPromptTemplate systemPromptTemplate;
|
||||||
|
protected String outType = "text"; //text markdown json
|
||||||
|
|
||||||
|
public LlmNode() {
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public LlmNode(Llm llm, String userPrompt) {
|
||||||
|
this.llm = llm;
|
||||||
|
this.userPrompt = userPrompt;
|
||||||
|
this.userPromptTemplate = StringUtil.hasText(userPrompt)
|
||||||
|
? TextPromptTemplate.of(userPrompt) : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public Llm getLlm() {
|
||||||
|
return llm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLlm(Llm llm) {
|
||||||
|
this.llm = llm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getUserPrompt() {
|
||||||
|
return userPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setUserPrompt(String userPrompt) {
|
||||||
|
this.userPrompt = userPrompt;
|
||||||
|
this.userPromptTemplate = StringUtil.hasText(userPrompt)
|
||||||
|
? TextPromptTemplate.of(userPrompt) : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getSystemPrompt() {
|
||||||
|
return systemPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSystemPrompt(String systemPrompt) {
|
||||||
|
this.systemPrompt = systemPrompt;
|
||||||
|
this.systemPromptTemplate = StringUtil.hasText(systemPrompt)
|
||||||
|
? TextPromptTemplate.of(systemPrompt) : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatOptions getChatOptions() {
|
||||||
|
return chatOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setChatOptions(ChatOptions chatOptions) {
|
||||||
|
if (chatOptions == null) {
|
||||||
|
chatOptions = ChatOptions.DEFAULT;
|
||||||
|
}
|
||||||
|
this.chatOptions = chatOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getOutType() {
|
||||||
|
return outType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOutType(String outType) {
|
||||||
|
this.outType = outType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Map<String, Object> execute(Chain chain) {
|
||||||
|
Map<String, Object> parameterValues = chain.getParameterValues(this);
|
||||||
|
|
||||||
|
if (userPromptTemplate == null) {
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
TextPrompt userPrompt = userPromptTemplate.format(parameterValues);
|
||||||
|
|
||||||
|
if (systemPromptTemplate != null) {
|
||||||
|
String systemPrompt = systemPromptTemplate.formatToString(parameterValues);
|
||||||
|
userPrompt.setSystemMessage(SystemMessage.of(systemPrompt));
|
||||||
|
}
|
||||||
|
|
||||||
|
AiMessageResponse response = llm.chat(userPrompt, chatOptions);
|
||||||
|
chain.output(this, response);
|
||||||
|
|
||||||
|
if (response == null) {
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (response.isError()) {
|
||||||
|
chain.stopError(response.getErrorMessage());
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
AiMessage aiMessage = response.getMessage();
|
||||||
|
if (aiMessage == null) {
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
String responseContent = aiMessage.getContent();
|
||||||
|
if (StringUtil.noText(responseContent)) {
|
||||||
|
chain.stopError("Can not get response content: " + response.getResponse());
|
||||||
|
return Collections.emptyMap();
|
||||||
|
} else {
|
||||||
|
responseContent = responseContent.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if ("json".equalsIgnoreCase(outType)) {
|
||||||
|
if (this.outputDefs != null) {
|
||||||
|
JSONObject jsonObject;
|
||||||
|
try {
|
||||||
|
jsonObject = JSON.parseObject(unWrapMarkdown(responseContent));
|
||||||
|
} catch (Exception e) {
|
||||||
|
chain.stopError("Can not parse json: " + response.getResponse() + " " + e.getMessage());
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
return getExecuteResultMap(outputDefs, jsonObject);
|
||||||
|
}
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (outType == null || outType.equalsIgnoreCase("text") || outType.equalsIgnoreCase("markdown")) {
|
||||||
|
else {
|
||||||
|
if (CollectionUtil.noItems(this.outputDefs)) {
|
||||||
|
return Maps.of("output", responseContent);
|
||||||
|
} else {
|
||||||
|
Parameter parameter = this.outputDefs.get(0);
|
||||||
|
return Maps.of(parameter.getName(), responseContent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除 ``` 或者 ```json 等
|
||||||
|
*
|
||||||
|
* @param markdown json内容
|
||||||
|
* @return 方法 json 内容
|
||||||
|
*/
|
||||||
|
public static String unWrapMarkdown(String markdown) {
|
||||||
|
// 移除开头的 ```json 或 ```
|
||||||
|
if (markdown.startsWith("```")) {
|
||||||
|
int newlineIndex = markdown.indexOf('\n');
|
||||||
|
if (newlineIndex != -1) {
|
||||||
|
markdown = markdown.substring(newlineIndex + 1);
|
||||||
|
} else {
|
||||||
|
// 如果没有换行符,直接去掉 ``` 部分
|
||||||
|
markdown = markdown.substring(3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除结尾的 ```
|
||||||
|
if (markdown.endsWith("```")) {
|
||||||
|
markdown = markdown.substring(0, markdown.length() - 3);
|
||||||
|
}
|
||||||
|
return markdown.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> getExecuteResultMap(List<Parameter> outputDefs, JSONObject data) {
|
||||||
|
Map<String, Object> result = new HashMap<>();
|
||||||
|
outputDefs.forEach(output -> {
|
||||||
|
result.put(output.getName(), getOutputDefData(output, data, false));
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Object getOutputDefData(Parameter output, JSONObject data, boolean sub) {
|
||||||
|
String name = output.getName();
|
||||||
|
DataType dataType = output.getDataType();
|
||||||
|
switch (dataType) {
|
||||||
|
case Array:
|
||||||
|
case Array_Object:
|
||||||
|
if (output.getChildren() == null || output.getChildren().isEmpty()) {
|
||||||
|
return data.get(name);
|
||||||
|
}
|
||||||
|
List<Object> subResultList = new ArrayList<>();
|
||||||
|
Object dataObj = data.get(name);
|
||||||
|
if (dataObj instanceof JSONArray) {
|
||||||
|
JSONArray contentFields = ((JSONArray) dataObj);
|
||||||
|
if (!contentFields.isEmpty()) {
|
||||||
|
contentFields.forEach(field -> {
|
||||||
|
if (field instanceof JSONObject) {
|
||||||
|
subResultList.add(getChildrenResult(output.getChildren(), (JSONObject) field, sub));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subResultList;
|
||||||
|
case Object:
|
||||||
|
return (output.getChildren() != null && !output.getChildren().isEmpty()) ? getChildrenResult(output.getChildren(), sub ? data : (JSONObject) data.get(name), sub) : data.get(name);
|
||||||
|
case String:
|
||||||
|
case Number:
|
||||||
|
case Boolean:
|
||||||
|
Object obj = data.get(name);
|
||||||
|
return (DataType.String == dataType) ? (obj instanceof String ? obj : "") : (DataType.Number == dataType) ? (obj instanceof Number ? obj : 0) : obj instanceof Boolean ? obj : false;
|
||||||
|
case Array_String:
|
||||||
|
case Array_Number:
|
||||||
|
case Array_Boolean:
|
||||||
|
Object arrayObj = data.get(name);
|
||||||
|
if (arrayObj instanceof JSONArray) {
|
||||||
|
((JSONArray) arrayObj).removeIf(o -> arrayRemoveFlag(dataType, o));
|
||||||
|
return arrayObj;
|
||||||
|
}
|
||||||
|
return Collections.emptyList();
|
||||||
|
default:
|
||||||
|
return ""; // FILE和其他不支持的类型,默认空字符串
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean arrayRemoveFlag(DataType dataType, Object arrayObj) {
|
||||||
|
boolean removeFlag = false;
|
||||||
|
if (DataType.Array_String == dataType) {
|
||||||
|
if (!(arrayObj instanceof String)) {
|
||||||
|
removeFlag = true;
|
||||||
|
}
|
||||||
|
} else if (DataType.Array_Number == dataType) {
|
||||||
|
if (!(arrayObj instanceof Number)) {
|
||||||
|
removeFlag = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!(arrayObj instanceof Boolean)) {
|
||||||
|
removeFlag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return removeFlag;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, Object> getChildrenResult(List<Parameter> children, JSONObject data, boolean sub) {
|
||||||
|
Map<String, Object> childrenResult = new HashMap<>();
|
||||||
|
children.forEach(child -> {
|
||||||
|
String childName = child.getName();
|
||||||
|
Object subData = getOutputDefData(child, data, sub);
|
||||||
|
if ((subData instanceof JSONObject) && (child.getChildren() != null && !child.getChildren().isEmpty())) {
|
||||||
|
getChildrenResult(child.getChildren(), (JSONObject) subData, true);
|
||||||
|
} else {
|
||||||
|
childrenResult.put(childName, subData);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return childrenResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "LlmNode{" +
|
||||||
|
"llm=" + llm +
|
||||||
|
", chatOptions=" + chatOptions +
|
||||||
|
", userPrompt='" + userPrompt + '\'' +
|
||||||
|
", userPromptTemplate=" + userPromptTemplate +
|
||||||
|
", systemPrompt='" + systemPrompt + '\'' +
|
||||||
|
", systemPromptTemplate=" + systemPromptTemplate +
|
||||||
|
", outType='" + outType + '\'' +
|
||||||
|
", description='" + description + '\'' +
|
||||||
|
", parameters=" + parameters +
|
||||||
|
", outputDefs=" + outputDefs +
|
||||||
|
", id='" + id + '\'' +
|
||||||
|
", name='" + name + '\'' +
|
||||||
|
", async=" + async +
|
||||||
|
", inwardEdges=" + inwardEdges +
|
||||||
|
", outwardEdges=" + outwardEdges +
|
||||||
|
", condition=" + condition +
|
||||||
|
", memory=" + memory +
|
||||||
|
", nodeStatus=" + nodeStatus +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user