This commit is contained in:
2025-09-04 14:09:29 +08:00
parent 90583c0cbb
commit a7323421d7

View File

@@ -0,0 +1,97 @@
package org.dromara.easyai.naturalLanguage;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.entity.TalkBody;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.naturalLanguage.word.WordBack;
import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
import org.dromara.easyai.transFormer.TransFormerManager;
import org.dromara.easyai.transFormer.TransWordVector;
import org.dromara.easyai.transFormer.model.TransFormerModel;
import org.dromara.easyai.transFormer.nerve.SensoryNerve;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class TalkToTalk {
private final TfConfig tfConfig;
private final int maxLength;
private final int times;
private final TransFormerManager transFormerManager = new TransFormerManager();
private final boolean splitAnswer;//回答是否带隔断符
private final String splitWord;//隔断符
public TalkToTalk(TfConfig tfConfig) throws Exception {
splitWord = tfConfig.getSplitWord();
splitAnswer = splitWord != null && !splitWord.isEmpty();
this.tfConfig = tfConfig;
maxLength = tfConfig.getMaxLength();
this.times = tfConfig.getTimes();
if (times <= 0) {
throw new Exception("参数times必须大于0");
}
}
private void init(List<TalkBody> talkBodies) throws Exception {
List<String> sentenceList = new ArrayList<>();
for (TalkBody talkBody : talkBodies) {
sentenceList.add(talkBody.getQuestion());
sentenceList.add(talkBody.getAnswer());
}
transFormerManager.init(tfConfig, sentenceList);
}
public String getAnswer(String question, long eventID) throws Exception {
SensoryNerve sensoryNerve = transFormerManager.getSensoryNerve();
TransWordVector transWordVector = transFormerManager.getTransWordVector();
int end = transWordVector.getEndID();
WordBack wordBack = new WordBack();
int id;
StringBuilder answer = new StringBuilder();
int index = 0;
do {
String myAnswer = null;
if (answer.length() > 0) {
myAnswer = answer.toString();
}
sensoryNerve.postSentence(eventID, question, myAnswer, false, wordBack);
id = wordBack.getId();
if (id != end) {//没有结束
String word = transWordVector.getWordByID(id);
if (splitAnswer) {
answer.append(splitWord).append(word);
} else {
answer.append(word);
}
}
index++;
} while (id != end && index < maxLength);
String result = answer.toString();
return result.replace(tfConfig.startWord, "");
}
public void insertModel(TransFormerModel transFormerModel) throws Exception {
transFormerManager.insertModel(transFormerModel, tfConfig);
}
public TransFormerModel study(List<TalkBody> talkBodies) throws Exception {
init(talkBodies);
SensoryNerve sensoryNerve = transFormerManager.getSensoryNerve();
int size = talkBodies.size();
for (int k = 0; k < times; k++) {
int index = 0;
for (TalkBody talkBody : talkBodies) {
index++;
String question = talkBody.getQuestion();
String answer = talkBody.getAnswer();
System.out.println("问题:" + question + ", 回答:" + answer + ",训练语句下标:" + index + ",总数量:" + size + ",当前次数:" + k + ",总次数:" + times);
sensoryNerve.postSentence(1, question, answer, true, null);
}
}
return transFormerManager.getModel();
}
}