This commit is contained in:
2025-08-27 19:57:23 +08:00
parent 93ad5a6ea6
commit 1aaf19a540

View File

@@ -0,0 +1,221 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.llm.spark;
import com.agentsflex.core.document.Document;
import com.agentsflex.core.llm.functions.Function;
import com.agentsflex.core.llm.functions.Parameter;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.message.*;
import com.agentsflex.core.parser.AiMessageParser;
import com.agentsflex.core.parser.impl.DefaultAiMessageParser;
import com.agentsflex.core.prompt.DefaultPromptFormat;
import com.agentsflex.core.prompt.Prompt;
import com.agentsflex.core.prompt.PromptFormat;
import com.agentsflex.core.util.HashUtil;
import com.agentsflex.core.util.Maps;
import com.agentsflex.core.util.MessageUtil;
import com.alibaba.fastjson.*;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.text.SimpleDateFormat;
import java.util.*;
public class SparkLlmUtil {
private static final PromptFormat promptFormat = new DefaultPromptFormat() {
@Override
protected void buildFunctionJsonArray(List<Map<String, Object>> functionsJsonArray, List<Function> functions) {
for (Function function : functions) {
Map<String, Object> propertiesMap = new HashMap<>();
List<String> requiredProperties = new ArrayList<>();
Parameter[] parameters = function.getParameters();
if (parameters != null) {
for (Parameter parameter : parameters) {
if (parameter.isRequired()) {
requiredProperties.add(parameter.getName());
}
propertiesMap.put(parameter.getName(), Maps.of("type", parameter.getType()).set("description", parameter.getDescription()));
}
}
Maps builder = Maps.of("name", function.getName())
.set("description", function.getDescription())
.set("parameters", Maps.of("type", "object").set("properties", propertiesMap).set("required", requiredProperties));
functionsJsonArray.add(builder);
}
}
};
public static AiMessageParser getAiMessageParser() {
DefaultAiMessageParser aiMessageParser = new DefaultAiMessageParser() {
@Override
public AiMessage parse(JSONObject rootJson) {
if (!rootJson.containsKey("payload")) {
throw new JSONException("json not contains payload: " + rootJson);
}
return super.parse(rootJson);
}
};
aiMessageParser.setContentPath("$.payload.choices.text[0].content");
aiMessageParser.setIndexPath("$.payload.choices.text[0].index");
aiMessageParser.setCompletionTokensPath("$.payload.usage.text.completion_tokens");
aiMessageParser.setPromptTokensPath("$.payload.usage.text.prompt_tokens");
aiMessageParser.setTotalTokensPath("$.payload.usage.text.total_tokens");
aiMessageParser.setStatusParser(content -> {
Integer status = (Integer) JSONPath.eval(content, "$.payload.choices.status");
if (status == null) {
return MessageStatus.UNKNOW;
}
switch (status) {
case 0:
return MessageStatus.START;
case 1:
return MessageStatus.MIDDLE;
case 2:
return MessageStatus.END;
}
return MessageStatus.UNKNOW;
});
aiMessageParser.setCallsParser(content -> {
JSONArray toolCalls = (JSONArray) JSONPath.eval(content, "$.payload.choices.text");
if (toolCalls == null || toolCalls.isEmpty()) {
return Collections.emptyList();
}
List<FunctionCall> functionCalls = new ArrayList<>();
for (int i = 0; i < toolCalls.size(); i++) {
JSONObject jsonObject = toolCalls.getJSONObject(i);
JSONObject functionObject = jsonObject.getJSONObject("function_call");
if (functionObject != null) {
FunctionCall functionCall = new FunctionCall();
functionCall.setId(jsonObject.getString("id"));
functionCall.setName(functionObject.getString("name"));
Object arguments = functionObject.get("arguments");
if (arguments instanceof Map) {
//noinspection unchecked
functionCall.setArgs((Map<String, Object>) arguments);
} else if (arguments instanceof String) {
//noinspection unchecked
functionCall.setArgs(JSON.parseObject(arguments.toString(), Map.class));
}
functionCalls.add(functionCall);
}
}
return functionCalls;
});
return aiMessageParser;
}
public static String promptToPayload(Prompt prompt, SparkLlmConfig config, ChatOptions options) {
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
List<Message> messages = prompt.toMessages();
HumanMessage message = MessageUtil.findLastHumanMessage(messages);
Maps root = Maps.of("header", Maps.of("app_id", config.getAppId()).set("uid", UUID.randomUUID().toString().replaceAll("-", "")));
root.set("parameter", Maps.of("chat", Maps.of("domain", getDomain(config.getVersion()))
.setIf(options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.setIfNotNull("top_k", options.getTopK())
)
);
root.set("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonObject(messages)))
.setIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonObject(message)))
);
root.setIfNotEmpty(options.getExtra());
return JSON.toJSONString(root);
}
public static String createURL(SparkLlmConfig config) {
SimpleDateFormat sdf = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss '+0000'", Locale.US);
sdf.setTimeZone(java.util.TimeZone.getTimeZone("UTC"));
String date = sdf.format(new Date());
String header = "host: spark-api.xf-yun.com\n";
header += "date: " + date + "\n";
header += "GET /" + config.getVersion() + "/chat HTTP/1.1";
String base64 = HashUtil.hmacSHA256ToBase64(header, config.getApiSecret());
String authorization_origin = "api_key=\"" + config.getApiKey()
+ "\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"" + base64 + "\"";
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes());
return "ws://spark-api.xf-yun.com/" + config.getVersion() + "/chat?authorization=" + authorization
+ "&date=" + urlEncode(date) + "&host=spark-api.xf-yun.com";
}
private static String urlEncode(String content) {
try {
return URLEncoder.encode(content, "utf-8").replace("+", "%20");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
}
}
private static String getDomain(String version) {
switch (version) {
case "v4.0":
return "4.0Ultra";
case "v3.5":
return "generalv3.5";
case "v3.1":
return "generalv3";
case "v2.1":
return "generalv2";
case "v1.1":
return "lite";
default:
return "general";
}
}
public static String embedPayload(SparkLlmConfig config, Document document) {
String text = Maps.of("messages", Collections.singletonList(Maps.of("content", document.getContent()).set("role", "user"))).toJSON();
String textBase64 = Base64.getEncoder().encodeToString(text.getBytes());
return Maps.of("header", Maps.of("app_id", config.getAppId()).set("uid", UUID.randomUUID()).set("status", 3))
.set("parameter", Maps.of("emb", Maps.of("domain", "para").set("feature", Maps.of("encoding", "utf8").set("compress", "raw").set("format", "plain"))))
.set("payload", Maps.of("messages", Maps.of("encoding", "utf8").set("compress", "raw").set("format", "json").set("status", 3).set("text", textBase64)))
.toJSON();
}
/// http://emb-cn-huabei-1.xf-yun.com/
public static String createEmbedURL(SparkLlmConfig config) {
SimpleDateFormat sdf = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss '+0000'", Locale.US);
sdf.setTimeZone(java.util.TimeZone.getTimeZone("UTC"));
String date = sdf.format(new Date());
String header = "host: emb-cn-huabei-1.xf-yun.com\n";
header += "date: " + date + "\n";
header += "POST / HTTP/1.1";
String base64 = HashUtil.hmacSHA256ToBase64(header, config.getApiSecret());
String authorization_origin = "api_key=\"" + config.getApiKey()
+ "\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"" + base64 + "\"";
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes());
return "http://emb-cn-huabei-1.xf-yun.com/?authorization=" + authorization
+ "&date=" + urlEncode(date) + "&host=emb-cn-huabei-1.xf-yun.com";
}
}