Add File
This commit is contained in:
@@ -0,0 +1,307 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.core.chain.node;
|
||||
|
||||
import com.agentsflex.core.chain.Chain;
|
||||
import com.agentsflex.core.chain.DataType;
|
||||
import com.agentsflex.core.chain.Parameter;
|
||||
import com.agentsflex.core.llm.ChatOptions;
|
||||
import com.agentsflex.core.llm.Llm;
|
||||
import com.agentsflex.core.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.core.message.AiMessage;
|
||||
import com.agentsflex.core.message.SystemMessage;
|
||||
import com.agentsflex.core.prompt.TextPrompt;
|
||||
import com.agentsflex.core.prompt.template.TextPromptTemplate;
|
||||
import com.agentsflex.core.util.CollectionUtil;
|
||||
import com.agentsflex.core.util.Maps;
|
||||
import com.agentsflex.core.util.StringUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class LlmNode extends BaseNode {
|
||||
|
||||
protected Llm llm;
|
||||
protected ChatOptions chatOptions = ChatOptions.DEFAULT;
|
||||
protected String userPrompt;
|
||||
protected TextPromptTemplate userPromptTemplate;
|
||||
|
||||
protected String systemPrompt;
|
||||
protected TextPromptTemplate systemPromptTemplate;
|
||||
protected String outType = "text"; //text markdown json
|
||||
|
||||
public LlmNode() {
|
||||
}
|
||||
|
||||
|
||||
public LlmNode(Llm llm, String userPrompt) {
|
||||
this.llm = llm;
|
||||
this.userPrompt = userPrompt;
|
||||
this.userPromptTemplate = StringUtil.hasText(userPrompt)
|
||||
? TextPromptTemplate.of(userPrompt) : null;
|
||||
}
|
||||
|
||||
|
||||
public Llm getLlm() {
|
||||
return llm;
|
||||
}
|
||||
|
||||
public void setLlm(Llm llm) {
|
||||
this.llm = llm;
|
||||
}
|
||||
|
||||
public String getUserPrompt() {
|
||||
return userPrompt;
|
||||
}
|
||||
|
||||
public void setUserPrompt(String userPrompt) {
|
||||
this.userPrompt = userPrompt;
|
||||
this.userPromptTemplate = StringUtil.hasText(userPrompt)
|
||||
? TextPromptTemplate.of(userPrompt) : null;
|
||||
}
|
||||
|
||||
public String getSystemPrompt() {
|
||||
return systemPrompt;
|
||||
}
|
||||
|
||||
public void setSystemPrompt(String systemPrompt) {
|
||||
this.systemPrompt = systemPrompt;
|
||||
this.systemPromptTemplate = StringUtil.hasText(systemPrompt)
|
||||
? TextPromptTemplate.of(systemPrompt) : null;
|
||||
}
|
||||
|
||||
public ChatOptions getChatOptions() {
|
||||
return chatOptions;
|
||||
}
|
||||
|
||||
public void setChatOptions(ChatOptions chatOptions) {
|
||||
if (chatOptions == null) {
|
||||
chatOptions = ChatOptions.DEFAULT;
|
||||
}
|
||||
this.chatOptions = chatOptions;
|
||||
}
|
||||
|
||||
public String getOutType() {
|
||||
return outType;
|
||||
}
|
||||
|
||||
public void setOutType(String outType) {
|
||||
this.outType = outType;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Map<String, Object> execute(Chain chain) {
|
||||
Map<String, Object> parameterValues = chain.getParameterValues(this);
|
||||
|
||||
if (userPromptTemplate == null) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
TextPrompt userPrompt = userPromptTemplate.format(parameterValues);
|
||||
|
||||
if (systemPromptTemplate != null) {
|
||||
String systemPrompt = systemPromptTemplate.formatToString(parameterValues);
|
||||
userPrompt.setSystemMessage(SystemMessage.of(systemPrompt));
|
||||
}
|
||||
|
||||
AiMessageResponse response = llm.chat(userPrompt, chatOptions);
|
||||
chain.output(this, response);
|
||||
|
||||
if (response == null) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
if (response.isError()) {
|
||||
chain.stopError(response.getErrorMessage());
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
AiMessage aiMessage = response.getMessage();
|
||||
if (aiMessage == null) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
|
||||
String responseContent = aiMessage.getContent();
|
||||
if (StringUtil.noText(responseContent)) {
|
||||
chain.stopError("Can not get response content: " + response.getResponse());
|
||||
return Collections.emptyMap();
|
||||
} else {
|
||||
responseContent = responseContent.trim();
|
||||
}
|
||||
|
||||
|
||||
if ("json".equalsIgnoreCase(outType)) {
|
||||
if (this.outputDefs != null) {
|
||||
JSONObject jsonObject;
|
||||
try {
|
||||
jsonObject = JSON.parseObject(unWrapMarkdown(responseContent));
|
||||
} catch (Exception e) {
|
||||
chain.stopError("Can not parse json: " + response.getResponse() + " " + e.getMessage());
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
return getExecuteResultMap(outputDefs, jsonObject);
|
||||
}
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
// if (outType == null || outType.equalsIgnoreCase("text") || outType.equalsIgnoreCase("markdown")) {
|
||||
else {
|
||||
if (CollectionUtil.noItems(this.outputDefs)) {
|
||||
return Maps.of("output", responseContent);
|
||||
} else {
|
||||
Parameter parameter = this.outputDefs.get(0);
|
||||
return Maps.of(parameter.getName(), responseContent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 移除 ``` 或者 ```json 等
|
||||
*
|
||||
* @param markdown json内容
|
||||
* @return 方法 json 内容
|
||||
*/
|
||||
public static String unWrapMarkdown(String markdown) {
|
||||
// 移除开头的 ```json 或 ```
|
||||
if (markdown.startsWith("```")) {
|
||||
int newlineIndex = markdown.indexOf('\n');
|
||||
if (newlineIndex != -1) {
|
||||
markdown = markdown.substring(newlineIndex + 1);
|
||||
} else {
|
||||
// 如果没有换行符,直接去掉 ``` 部分
|
||||
markdown = markdown.substring(3);
|
||||
}
|
||||
}
|
||||
|
||||
// 移除结尾的 ```
|
||||
if (markdown.endsWith("```")) {
|
||||
markdown = markdown.substring(0, markdown.length() - 3);
|
||||
}
|
||||
return markdown.trim();
|
||||
}
|
||||
|
||||
public static Map<String, Object> getExecuteResultMap(List<Parameter> outputDefs, JSONObject data) {
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
outputDefs.forEach(output -> {
|
||||
result.put(output.getName(), getOutputDefData(output, data, false));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Object getOutputDefData(Parameter output, JSONObject data, boolean sub) {
|
||||
String name = output.getName();
|
||||
DataType dataType = output.getDataType();
|
||||
switch (dataType) {
|
||||
case Array:
|
||||
case Array_Object:
|
||||
if (output.getChildren() == null || output.getChildren().isEmpty()) {
|
||||
return data.get(name);
|
||||
}
|
||||
List<Object> subResultList = new ArrayList<>();
|
||||
Object dataObj = data.get(name);
|
||||
if (dataObj instanceof JSONArray) {
|
||||
JSONArray contentFields = ((JSONArray) dataObj);
|
||||
if (!contentFields.isEmpty()) {
|
||||
contentFields.forEach(field -> {
|
||||
if (field instanceof JSONObject) {
|
||||
subResultList.add(getChildrenResult(output.getChildren(), (JSONObject) field, sub));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
return subResultList;
|
||||
case Object:
|
||||
return (output.getChildren() != null && !output.getChildren().isEmpty()) ? getChildrenResult(output.getChildren(), sub ? data : (JSONObject) data.get(name), sub) : data.get(name);
|
||||
case String:
|
||||
case Number:
|
||||
case Boolean:
|
||||
Object obj = data.get(name);
|
||||
return (DataType.String == dataType) ? (obj instanceof String ? obj : "") : (DataType.Number == dataType) ? (obj instanceof Number ? obj : 0) : obj instanceof Boolean ? obj : false;
|
||||
case Array_String:
|
||||
case Array_Number:
|
||||
case Array_Boolean:
|
||||
Object arrayObj = data.get(name);
|
||||
if (arrayObj instanceof JSONArray) {
|
||||
((JSONArray) arrayObj).removeIf(o -> arrayRemoveFlag(dataType, o));
|
||||
return arrayObj;
|
||||
}
|
||||
return Collections.emptyList();
|
||||
default:
|
||||
return ""; // FILE和其他不支持的类型,默认空字符串
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean arrayRemoveFlag(DataType dataType, Object arrayObj) {
|
||||
boolean removeFlag = false;
|
||||
if (DataType.Array_String == dataType) {
|
||||
if (!(arrayObj instanceof String)) {
|
||||
removeFlag = true;
|
||||
}
|
||||
} else if (DataType.Array_Number == dataType) {
|
||||
if (!(arrayObj instanceof Number)) {
|
||||
removeFlag = true;
|
||||
}
|
||||
} else {
|
||||
if (!(arrayObj instanceof Boolean)) {
|
||||
removeFlag = true;
|
||||
}
|
||||
}
|
||||
return removeFlag;
|
||||
}
|
||||
|
||||
private static Map<String, Object> getChildrenResult(List<Parameter> children, JSONObject data, boolean sub) {
|
||||
Map<String, Object> childrenResult = new HashMap<>();
|
||||
children.forEach(child -> {
|
||||
String childName = child.getName();
|
||||
Object subData = getOutputDefData(child, data, sub);
|
||||
if ((subData instanceof JSONObject) && (child.getChildren() != null && !child.getChildren().isEmpty())) {
|
||||
getChildrenResult(child.getChildren(), (JSONObject) subData, true);
|
||||
} else {
|
||||
childrenResult.put(childName, subData);
|
||||
}
|
||||
});
|
||||
return childrenResult;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "LlmNode{" +
|
||||
"llm=" + llm +
|
||||
", chatOptions=" + chatOptions +
|
||||
", userPrompt='" + userPrompt + '\'' +
|
||||
", userPromptTemplate=" + userPromptTemplate +
|
||||
", systemPrompt='" + systemPrompt + '\'' +
|
||||
", systemPromptTemplate=" + systemPromptTemplate +
|
||||
", outType='" + outType + '\'' +
|
||||
", description='" + description + '\'' +
|
||||
", parameters=" + parameters +
|
||||
", outputDefs=" + outputDefs +
|
||||
", id='" + id + '\'' +
|
||||
", name='" + name + '\'' +
|
||||
", async=" + async +
|
||||
", inwardEdges=" + inwardEdges +
|
||||
", outwardEdges=" + outwardEdges +
|
||||
", condition=" + condition +
|
||||
", memory=" + memory +
|
||||
", nodeStatus=" + nodeStatus +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user