Add File
This commit is contained in:
@@ -0,0 +1,208 @@
|
|||||||
|
package org.dromara.easyai.rnnJumpNerveCenter;
|
||||||
|
|
||||||
|
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.config.SentenceConfig;
|
||||||
|
import org.dromara.easyai.entity.TypeMapping;
|
||||||
|
import org.dromara.easyai.entity.WordBack;
|
||||||
|
import org.dromara.easyai.function.Tanh;
|
||||||
|
import org.dromara.easyai.i.OutBack;
|
||||||
|
import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
|
||||||
|
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
|
||||||
|
import org.dromara.easyai.rnnJumpNerveEntity.SensoryNerve;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class RRNerveManager {
|
||||||
|
private final WordEmbedding wordEmbedding;
|
||||||
|
private final Map<Integer, Integer> mapping = new HashMap<>();//主键是真实id,值是映射识别用id
|
||||||
|
private NerveJumpManager typeNerveManager;//类别网络
|
||||||
|
private int typeNub;//分类数量
|
||||||
|
private int vectorDimension;//特征纵向维度
|
||||||
|
private int maxFeatureLength;//特征最长长度
|
||||||
|
private float studyPoint;//词向量学习学习率
|
||||||
|
private boolean showLog;//是否输出学习数据
|
||||||
|
private int minLength;//最小长度
|
||||||
|
private float trustPowerTh = 0;//可信阈值
|
||||||
|
private int rzModel;//正则模式
|
||||||
|
private float rzParam;//正则系数
|
||||||
|
|
||||||
|
public RRNerveManager(WordEmbedding wordEmbedding) {
|
||||||
|
this.wordEmbedding = wordEmbedding;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void init(SentenceConfig config) throws Exception {
|
||||||
|
if (config.getTypeNub() > 0) {
|
||||||
|
this.trustPowerTh = config.getTrustPowerTh();
|
||||||
|
this.minLength = config.getMinLength();
|
||||||
|
this.typeNub = config.getTypeNub();
|
||||||
|
this.vectorDimension = config.getWordVectorDimension();
|
||||||
|
this.maxFeatureLength = config.getMaxWordLength();
|
||||||
|
this.studyPoint = config.getWeStudyPoint();
|
||||||
|
this.showLog = config.isShowLog();
|
||||||
|
this.rzModel = config.getRzModel();
|
||||||
|
this.rzParam = config.getParam();
|
||||||
|
initNerveManager();
|
||||||
|
} else {
|
||||||
|
throw new Exception("分类种类数量必须大于0");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initNerveManager() throws Exception {
|
||||||
|
typeNerveManager = new NerveJumpManager(vectorDimension, vectorDimension, typeNub, maxFeatureLength - 1, new Tanh(), false,
|
||||||
|
studyPoint, rzModel, rzParam);
|
||||||
|
typeNerveManager.initRnn(true, showLog, true, false, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getMappingType(int key) {//通过自增主键查找原映射
|
||||||
|
int id = 0;
|
||||||
|
for (Map.Entry<Integer, Integer> entry : mapping.entrySet()) {
|
||||||
|
if (entry.getValue() == key) {
|
||||||
|
id = entry.getKey();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int balance(Map<Integer, List<String>> model) {//强行均衡
|
||||||
|
int maxNumber = 300;
|
||||||
|
int index = 1;
|
||||||
|
for (Map.Entry<Integer, List<String>> entry : model.entrySet()) {//查找最大数量
|
||||||
|
mapping.put(entry.getKey(), index);
|
||||||
|
if (entry.getValue().size() > maxNumber) {
|
||||||
|
maxNumber = entry.getValue().size();
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
for (Map.Entry<Integer, List<String>> entry : model.entrySet()) {
|
||||||
|
int size = entry.getValue().size();
|
||||||
|
if (maxNumber > size) {
|
||||||
|
int times = maxNumber / size - 1;//循环几次
|
||||||
|
int sub = maxNumber % size;//余数
|
||||||
|
List<String> list = entry.getValue();
|
||||||
|
List<String> otherList = new ArrayList<>(list);
|
||||||
|
for (int i = 0; i < times; i++) {
|
||||||
|
list.addAll(otherList);
|
||||||
|
}
|
||||||
|
list.addAll(otherList.subList(0, sub));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxNumber;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void studyNerve(long eventId, List<SensoryNerve> sensoryNerves, List<Float> featureList, Matrix rnnMatrix, Map<Integer, Float> E, boolean isStudy, OutBack convBack, int[] storeys) throws Exception {
|
||||||
|
if (sensoryNerves.size() == featureList.size()) {
|
||||||
|
for (int i = 0; i < sensoryNerves.size(); i++) {
|
||||||
|
sensoryNerves.get(i).postMessage(eventId, featureList.get(i), isStudy, E, convBack, rnnMatrix, storeys, 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Exception("1size not equals,feature size:" + featureList.size() + "," +
|
||||||
|
"sensorySize:" + sensoryNerves.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getType(String sentence, long eventID) throws Exception {//进行理解
|
||||||
|
if (sentence.length() > maxFeatureLength) {
|
||||||
|
sentence = sentence.substring(0, maxFeatureLength);
|
||||||
|
}
|
||||||
|
MyWordFeature myWordFeature = wordEmbedding.getEmbedding(sentence, eventID, false);
|
||||||
|
List<Float> featureList = myWordFeature.getFirstFeatureList();
|
||||||
|
Matrix featureMatrix = myWordFeature.getFeatureMatrix();
|
||||||
|
int[] storeys = new int[featureMatrix.getX()];
|
||||||
|
for (int i = 0; i < storeys.length; i++) {
|
||||||
|
storeys[i] = i;
|
||||||
|
}
|
||||||
|
WordBack wordBack = new WordBack();//trustPowerTh
|
||||||
|
studyNerve(eventID, typeNerveManager.getSensoryNerves(), featureList, featureMatrix, null, false, wordBack, storeys);
|
||||||
|
if (wordBack.getOut() > trustPowerTh) {
|
||||||
|
return getMappingType(wordBack.getId());
|
||||||
|
} else {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertModel(RandomModel randomModel) throws Exception {
|
||||||
|
typeNerveManager.insertModelParameter(randomModel.getTypeModelParameter());
|
||||||
|
List<TypeMapping> typeMappings = randomModel.getTypeMappings();
|
||||||
|
mapping.clear();
|
||||||
|
for (TypeMapping typeMapping : typeMappings) {
|
||||||
|
mapping.put(typeMapping.getType(), typeMapping.getMapping());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomModel getModel() throws Exception {
|
||||||
|
RandomModel randomModel = new RandomModel();
|
||||||
|
randomModel.setTypeModelParameter(typeNerveManager.getModelParameter());
|
||||||
|
List<TypeMapping> typeMappings = new ArrayList<>();
|
||||||
|
randomModel.setTypeMappings(typeMappings);
|
||||||
|
for (Map.Entry<Integer, Integer> entry : mapping.entrySet()) {
|
||||||
|
TypeMapping typeMapping = new TypeMapping();
|
||||||
|
typeMapping.setType(entry.getKey());
|
||||||
|
typeMapping.setMapping(entry.getValue());
|
||||||
|
typeMappings.add(typeMapping);
|
||||||
|
}
|
||||||
|
return randomModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomModel studyType(Map<Integer, List<String>> model) throws Exception {
|
||||||
|
int maxNumber = balance(model);//平衡样本
|
||||||
|
for (int i = 0; i < maxFeatureLength; i++) {//第一阶段学习
|
||||||
|
System.out.println("1第:" + (i + 1) + "次。共:" + maxFeatureLength + "次");
|
||||||
|
myStudy(maxNumber, model, i + 1);
|
||||||
|
}
|
||||||
|
return getModel();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void myStudy(int maxNumber, Map<Integer, List<String>> model, int time) throws Exception {
|
||||||
|
int index = 0;
|
||||||
|
Map<Integer, Float> E = new HashMap<>();
|
||||||
|
do {
|
||||||
|
for (Map.Entry<Integer, List<String>> entry : model.entrySet()) {
|
||||||
|
System.out.println("index======" + index + "," + time + "次");
|
||||||
|
E.clear();
|
||||||
|
List<String> sentence = entry.getValue();
|
||||||
|
int key = mapping.get(entry.getKey());
|
||||||
|
E.put(key, 1f);
|
||||||
|
String word = sentence.get(index);
|
||||||
|
if (word.length() > maxFeatureLength) {
|
||||||
|
word = word.substring(0, maxFeatureLength);
|
||||||
|
}
|
||||||
|
randomTypeStudy(wordEmbedding.getEmbedding(word, 1, false), E);
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
} while (index < maxNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void randomTypeStudy(MyWordFeature myWordFeature, Map<Integer, Float> E) throws Exception {
|
||||||
|
Matrix featureMatrix = myWordFeature.getFeatureMatrix();
|
||||||
|
List<Float> firstFeatureList = myWordFeature.getFirstFeatureList();
|
||||||
|
int len = featureMatrix.getX();//文字长度
|
||||||
|
Random random = new Random();
|
||||||
|
if (len > 1) {//长度大于1才可以进行训练
|
||||||
|
int[] storeys;
|
||||||
|
if (len < minLength) {
|
||||||
|
storeys = new int[len];
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
storeys[i] = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
List<Integer> list = new ArrayList<>();
|
||||||
|
for (int i = 1; i < len; i++) {
|
||||||
|
list.add(i);
|
||||||
|
}
|
||||||
|
int myLen = (int) (minLength + (float)Math.random() * (len - minLength + 1));
|
||||||
|
storeys = new int[myLen];
|
||||||
|
for (int i = 1; i < myLen; i++) {
|
||||||
|
int index = random.nextInt(list.size());
|
||||||
|
storeys[i] = list.get(index);
|
||||||
|
list.remove(index);
|
||||||
|
}
|
||||||
|
Arrays.sort(storeys);
|
||||||
|
}
|
||||||
|
studyNerve(1, typeNerveManager.getSensoryNerves(), firstFeatureList, featureMatrix
|
||||||
|
, E, true, null, storeys);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user