Add File
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
package org.dromara.easyai.naturalLanguage;
|
||||
|
||||
import org.dromara.easyai.config.TfConfig;
|
||||
import org.dromara.easyai.entity.TalkBody;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.naturalLanguage.word.WordBack;
|
||||
import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
|
||||
import org.dromara.easyai.transFormer.TransFormerManager;
|
||||
import org.dromara.easyai.transFormer.TransWordVector;
|
||||
import org.dromara.easyai.transFormer.model.TransFormerModel;
|
||||
import org.dromara.easyai.transFormer.nerve.SensoryNerve;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class TalkToTalk {
|
||||
private final TfConfig tfConfig;
|
||||
private final int maxLength;
|
||||
private final int times;
|
||||
private final TransFormerManager transFormerManager = new TransFormerManager();
|
||||
private final boolean splitAnswer;//回答是否带隔断符
|
||||
private final String splitWord;//隔断符
|
||||
|
||||
public TalkToTalk(TfConfig tfConfig) throws Exception {
|
||||
splitWord = tfConfig.getSplitWord();
|
||||
splitAnswer = splitWord != null && !splitWord.isEmpty();
|
||||
this.tfConfig = tfConfig;
|
||||
maxLength = tfConfig.getMaxLength();
|
||||
this.times = tfConfig.getTimes();
|
||||
if (times <= 0) {
|
||||
throw new Exception("参数times必须大于0");
|
||||
}
|
||||
}
|
||||
|
||||
private void init(List<TalkBody> talkBodies) throws Exception {
|
||||
List<String> sentenceList = new ArrayList<>();
|
||||
for (TalkBody talkBody : talkBodies) {
|
||||
sentenceList.add(talkBody.getQuestion());
|
||||
sentenceList.add(talkBody.getAnswer());
|
||||
}
|
||||
transFormerManager.init(tfConfig, sentenceList);
|
||||
}
|
||||
|
||||
|
||||
public String getAnswer(String question, long eventID) throws Exception {
|
||||
SensoryNerve sensoryNerve = transFormerManager.getSensoryNerve();
|
||||
TransWordVector transWordVector = transFormerManager.getTransWordVector();
|
||||
int end = transWordVector.getEndID();
|
||||
WordBack wordBack = new WordBack();
|
||||
int id;
|
||||
StringBuilder answer = new StringBuilder();
|
||||
int index = 0;
|
||||
do {
|
||||
String myAnswer = null;
|
||||
if (answer.length() > 0) {
|
||||
myAnswer = answer.toString();
|
||||
}
|
||||
sensoryNerve.postSentence(eventID, question, myAnswer, false, wordBack);
|
||||
id = wordBack.getId();
|
||||
if (id != end) {//没有结束
|
||||
String word = transWordVector.getWordByID(id);
|
||||
if (splitAnswer) {
|
||||
answer.append(splitWord).append(word);
|
||||
} else {
|
||||
answer.append(word);
|
||||
}
|
||||
}
|
||||
index++;
|
||||
} while (id != end && index < maxLength);
|
||||
String result = answer.toString();
|
||||
return result.replace(tfConfig.startWord, "");
|
||||
}
|
||||
|
||||
public void insertModel(TransFormerModel transFormerModel) throws Exception {
|
||||
transFormerManager.insertModel(transFormerModel, tfConfig);
|
||||
}
|
||||
|
||||
|
||||
public TransFormerModel study(List<TalkBody> talkBodies) throws Exception {
|
||||
init(talkBodies);
|
||||
SensoryNerve sensoryNerve = transFormerManager.getSensoryNerve();
|
||||
int size = talkBodies.size();
|
||||
for (int k = 0; k < times; k++) {
|
||||
int index = 0;
|
||||
for (TalkBody talkBody : talkBodies) {
|
||||
index++;
|
||||
String question = talkBody.getQuestion();
|
||||
String answer = talkBody.getAnswer();
|
||||
System.out.println("问题:" + question + ", 回答:" + answer + ",训练语句下标:" + index + ",总数量:" + size + ",当前次数:" + k + ",总次数:" + times);
|
||||
sensoryNerve.postSentence(1, question, answer, true, null);
|
||||
}
|
||||
}
|
||||
return transFormerManager.getModel();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user