Add File
This commit is contained in:
@@ -0,0 +1,175 @@
|
||||
package org.dromara.easyai.naturalLanguage.word;
|
||||
|
||||
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixList;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.config.RZ;
|
||||
import org.dromara.easyai.config.SentenceConfig;
|
||||
import org.dromara.easyai.entity.SentenceModel;
|
||||
import org.dromara.easyai.entity.WordMatrix;
|
||||
import org.dromara.easyai.entity.WordTwoVectorModel;
|
||||
import org.dromara.easyai.function.Tanh;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
|
||||
import org.dromara.easyai.rnnNerveCenter.NerveManager;
|
||||
import org.dromara.easyai.rnnNerveEntity.SensoryNerve;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
* @Author LiDaPeng
|
||||
* @Description 词嵌入向量训练
|
||||
*/
|
||||
public class WordEmbedding extends MatrixOperation {
|
||||
private NerveManager nerveManager;
|
||||
private SentenceModel sentenceModel;
|
||||
private final List<String> wordList = new ArrayList<>();//单字集合
|
||||
private SentenceConfig config;
|
||||
private int wordVectorDimension;
|
||||
private int studyTimes = 1;
|
||||
|
||||
public void setStudyTimes(int studyTimes) {
|
||||
this.studyTimes = studyTimes;
|
||||
}
|
||||
|
||||
public void setConfig(SentenceConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
public int getWordVectorDimension() {
|
||||
return wordVectorDimension;
|
||||
}
|
||||
|
||||
public void init(SentenceModel sentenceModel, int wordVectorDimension) throws Exception {
|
||||
this.wordVectorDimension = wordVectorDimension;
|
||||
this.sentenceModel = sentenceModel;
|
||||
wordList.addAll(sentenceModel.getWordSet());
|
||||
nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size()
|
||||
, 1, new Tanh(), config.getWeStudyPoint(), config.getRzModel(),
|
||||
config.getWeLParam());
|
||||
nerveManager.init(true, false, true);
|
||||
}
|
||||
|
||||
public List<String> getWordList() {
|
||||
return wordList;
|
||||
}
|
||||
|
||||
public String getWord(int id) {
|
||||
return wordList.get(id);
|
||||
}
|
||||
|
||||
public void insertModel(WordTwoVectorModel wordTwoVectorModel, int wordVectorDimension) throws Exception {
|
||||
wordList.clear();
|
||||
this.wordVectorDimension = wordVectorDimension;
|
||||
List<String> myWordList = wordTwoVectorModel.getWordList();
|
||||
wordList.addAll(myWordList);
|
||||
nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size()
|
||||
, 1, new Tanh(), config.getWeStudyPoint(), RZ.NOT_RZ, 0);
|
||||
nerveManager.init(true, false, true);
|
||||
nerveManager.insertModelParameter(wordTwoVectorModel.getModelParameter());
|
||||
}
|
||||
|
||||
public MyWordFeature getEmbedding(String word, long eventId, boolean once) throws Exception {//做截断
|
||||
MyWordFeature myWordFeature = new MyWordFeature();
|
||||
int wordDim = wordVectorDimension;//
|
||||
MatrixList matrixList = null;
|
||||
for (int i = 0; i < word.length(); i++) {
|
||||
WordMatrix wordMatrix = new WordMatrix(wordDim);
|
||||
String myWord;
|
||||
if (!once) {
|
||||
myWord = word.substring(i, i + 1);
|
||||
} else {
|
||||
myWord = word;
|
||||
}
|
||||
int index = getID(myWord);
|
||||
studyDNN(eventId, index, 0, wordMatrix, false);
|
||||
if (matrixList == null) {
|
||||
myWordFeature.setFirstFeatureList(wordMatrix.getList());
|
||||
matrixList = new MatrixList(wordMatrix.getVector(), true);
|
||||
} else {
|
||||
matrixList.add(wordMatrix.getVector());
|
||||
}
|
||||
if (once) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
myWordFeature.setFeatureMatrix(matrixList.getMatrix());
|
||||
return myWordFeature;
|
||||
}
|
||||
|
||||
private void studyDNN(long eventId, int featureIndex, int resIndex, OutBack outBack, boolean isStudy) throws Exception {
|
||||
List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves();
|
||||
int size = sensoryNerves.size();
|
||||
Map<Integer, Float> map = new HashMap<>();
|
||||
if (resIndex > 0) {
|
||||
map.put(resIndex + 1, 1f);
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
float feature = 0;
|
||||
if (i == featureIndex) {
|
||||
feature = 1;
|
||||
}
|
||||
sensoryNerves.get(i).postMessage(eventId, feature, isStudy, map, outBack, true, null);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public WordTwoVectorModel start() throws Exception {//开始进行词向量训练
|
||||
List<String[]> sentenceList = sentenceModel.getSentenceList();
|
||||
int size = sentenceList.size();
|
||||
System.out.println("词嵌入训练启动...");
|
||||
int allTimes = studyTimes * size;
|
||||
int index = 0;
|
||||
for (int k = 0; k < studyTimes; k++) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
index++;
|
||||
long start = System.currentTimeMillis();
|
||||
study(sentenceList.get(i));
|
||||
long end = System.currentTimeMillis() - start;
|
||||
float r = (float) index / allTimes * 100;
|
||||
String result = String.format("%.6f", r);
|
||||
System.out.println("size:" + size + ",index:" + i + ",耗时:" + end + ",完成度:" + result + "%");
|
||||
}
|
||||
}
|
||||
WordTwoVectorModel wordTwoVectorModel = new WordTwoVectorModel();
|
||||
wordTwoVectorModel.setModelParameter(nerveManager.getModelParameter());
|
||||
wordTwoVectorModel.setWordList(wordList);
|
||||
//词向量训练结束
|
||||
return wordTwoVectorModel;
|
||||
}
|
||||
|
||||
private void study(String[] word) throws Exception {
|
||||
int[] indexArray = new int[word.length];
|
||||
for (int i = 0; i < word.length; i++) {
|
||||
int index = getID(word[i]);
|
||||
indexArray[i] = index;
|
||||
}
|
||||
for (int i = 0; i < indexArray.length; i++) {
|
||||
int index = indexArray[i];
|
||||
for (int j = 0; j < indexArray.length; j++) {
|
||||
if (i != j) {
|
||||
int resIndex = indexArray[j];
|
||||
studyDNN(1, index, resIndex, null, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public int getID(String word) {
|
||||
int index = 0;
|
||||
int size = wordList.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (wordList.get(i).equals(word)) {
|
||||
index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user