Add File
This commit is contained in:
@@ -0,0 +1,175 @@
|
|||||||
|
package org.dromara.easyai.naturalLanguage.word;
|
||||||
|
|
||||||
|
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.matrixTools.MatrixList;
|
||||||
|
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||||
|
import org.dromara.easyai.config.RZ;
|
||||||
|
import org.dromara.easyai.config.SentenceConfig;
|
||||||
|
import org.dromara.easyai.entity.SentenceModel;
|
||||||
|
import org.dromara.easyai.entity.WordMatrix;
|
||||||
|
import org.dromara.easyai.entity.WordTwoVectorModel;
|
||||||
|
import org.dromara.easyai.function.Tanh;
|
||||||
|
import org.dromara.easyai.i.OutBack;
|
||||||
|
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
|
||||||
|
import org.dromara.easyai.rnnNerveCenter.NerveManager;
|
||||||
|
import org.dromara.easyai.rnnNerveEntity.SensoryNerve;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param
|
||||||
|
* @DATA
|
||||||
|
* @Author LiDaPeng
|
||||||
|
* @Description 词嵌入向量训练
|
||||||
|
*/
|
||||||
|
public class WordEmbedding extends MatrixOperation {
|
||||||
|
private NerveManager nerveManager;
|
||||||
|
private SentenceModel sentenceModel;
|
||||||
|
private final List<String> wordList = new ArrayList<>();//单字集合
|
||||||
|
private SentenceConfig config;
|
||||||
|
private int wordVectorDimension;
|
||||||
|
private int studyTimes = 1;
|
||||||
|
|
||||||
|
public void setStudyTimes(int studyTimes) {
|
||||||
|
this.studyTimes = studyTimes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setConfig(SentenceConfig config) {
|
||||||
|
this.config = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getWordVectorDimension() {
|
||||||
|
return wordVectorDimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void init(SentenceModel sentenceModel, int wordVectorDimension) throws Exception {
|
||||||
|
this.wordVectorDimension = wordVectorDimension;
|
||||||
|
this.sentenceModel = sentenceModel;
|
||||||
|
wordList.addAll(sentenceModel.getWordSet());
|
||||||
|
nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size()
|
||||||
|
, 1, new Tanh(), config.getWeStudyPoint(), config.getRzModel(),
|
||||||
|
config.getWeLParam());
|
||||||
|
nerveManager.init(true, false, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getWordList() {
|
||||||
|
return wordList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getWord(int id) {
|
||||||
|
return wordList.get(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertModel(WordTwoVectorModel wordTwoVectorModel, int wordVectorDimension) throws Exception {
|
||||||
|
wordList.clear();
|
||||||
|
this.wordVectorDimension = wordVectorDimension;
|
||||||
|
List<String> myWordList = wordTwoVectorModel.getWordList();
|
||||||
|
wordList.addAll(myWordList);
|
||||||
|
nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size()
|
||||||
|
, 1, new Tanh(), config.getWeStudyPoint(), RZ.NOT_RZ, 0);
|
||||||
|
nerveManager.init(true, false, true);
|
||||||
|
nerveManager.insertModelParameter(wordTwoVectorModel.getModelParameter());
|
||||||
|
}
|
||||||
|
|
||||||
|
public MyWordFeature getEmbedding(String word, long eventId, boolean once) throws Exception {//做截断
|
||||||
|
MyWordFeature myWordFeature = new MyWordFeature();
|
||||||
|
int wordDim = wordVectorDimension;//
|
||||||
|
MatrixList matrixList = null;
|
||||||
|
for (int i = 0; i < word.length(); i++) {
|
||||||
|
WordMatrix wordMatrix = new WordMatrix(wordDim);
|
||||||
|
String myWord;
|
||||||
|
if (!once) {
|
||||||
|
myWord = word.substring(i, i + 1);
|
||||||
|
} else {
|
||||||
|
myWord = word;
|
||||||
|
}
|
||||||
|
int index = getID(myWord);
|
||||||
|
studyDNN(eventId, index, 0, wordMatrix, false);
|
||||||
|
if (matrixList == null) {
|
||||||
|
myWordFeature.setFirstFeatureList(wordMatrix.getList());
|
||||||
|
matrixList = new MatrixList(wordMatrix.getVector(), true);
|
||||||
|
} else {
|
||||||
|
matrixList.add(wordMatrix.getVector());
|
||||||
|
}
|
||||||
|
if (once) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
myWordFeature.setFeatureMatrix(matrixList.getMatrix());
|
||||||
|
return myWordFeature;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void studyDNN(long eventId, int featureIndex, int resIndex, OutBack outBack, boolean isStudy) throws Exception {
|
||||||
|
List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves();
|
||||||
|
int size = sensoryNerves.size();
|
||||||
|
Map<Integer, Float> map = new HashMap<>();
|
||||||
|
if (resIndex > 0) {
|
||||||
|
map.put(resIndex + 1, 1f);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
float feature = 0;
|
||||||
|
if (i == featureIndex) {
|
||||||
|
feature = 1;
|
||||||
|
}
|
||||||
|
sensoryNerves.get(i).postMessage(eventId, feature, isStudy, map, outBack, true, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public WordTwoVectorModel start() throws Exception {//开始进行词向量训练
|
||||||
|
List<String[]> sentenceList = sentenceModel.getSentenceList();
|
||||||
|
int size = sentenceList.size();
|
||||||
|
System.out.println("词嵌入训练启动...");
|
||||||
|
int allTimes = studyTimes * size;
|
||||||
|
int index = 0;
|
||||||
|
for (int k = 0; k < studyTimes; k++) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
index++;
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
study(sentenceList.get(i));
|
||||||
|
long end = System.currentTimeMillis() - start;
|
||||||
|
float r = (float) index / allTimes * 100;
|
||||||
|
String result = String.format("%.6f", r);
|
||||||
|
System.out.println("size:" + size + ",index:" + i + ",耗时:" + end + ",完成度:" + result + "%");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
WordTwoVectorModel wordTwoVectorModel = new WordTwoVectorModel();
|
||||||
|
wordTwoVectorModel.setModelParameter(nerveManager.getModelParameter());
|
||||||
|
wordTwoVectorModel.setWordList(wordList);
|
||||||
|
//词向量训练结束
|
||||||
|
return wordTwoVectorModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void study(String[] word) throws Exception {
|
||||||
|
int[] indexArray = new int[word.length];
|
||||||
|
for (int i = 0; i < word.length; i++) {
|
||||||
|
int index = getID(word[i]);
|
||||||
|
indexArray[i] = index;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < indexArray.length; i++) {
|
||||||
|
int index = indexArray[i];
|
||||||
|
for (int j = 0; j < indexArray.length; j++) {
|
||||||
|
if (i != j) {
|
||||||
|
int resIndex = indexArray[j];
|
||||||
|
studyDNN(1, index, resIndex, null, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getID(String word) {
|
||||||
|
int index = 0;
|
||||||
|
int size = wordList.size();
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
if (wordList.get(i).equals(word)) {
|
||||||
|
index = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user