diff --git a/src/main/java/org/dromara/easyai/rnnJumpNerveEntity/SoftMax.java b/src/main/java/org/dromara/easyai/rnnJumpNerveEntity/SoftMax.java new file mode 100644 index 0000000..09d6c17 --- /dev/null +++ b/src/main/java/org/dromara/easyai/rnnJumpNerveEntity/SoftMax.java @@ -0,0 +1,122 @@ +package org.dromara.easyai.rnnJumpNerveEntity; + + +import org.dromara.easyai.matrixTools.Matrix; +import org.dromara.easyai.config.RZ; +import org.dromara.easyai.i.OutBack; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class SoftMax extends Nerve { + private final List outNerves; + private final boolean isShowLog; + private NerveCenter nerveCenter;//该输出层对应的神经中枢 + + public SoftMax(List outNerves, boolean isShowLog + , int sensoryNerveNub, int hiddenNerveNub, int outNerveNub, int allDepth) throws Exception { + super(0, "softMax", 0, false, null + , RZ.NOT_RZ, 0, sensoryNerveNub, hiddenNerveNub, outNerveNub, allDepth + , false, 0); + this.outNerves = outNerves; + this.isShowLog = isShowLog; + } + + public void setNerveCenter(NerveCenter nerveCenter) { + this.nerveCenter = nerveCenter; + } + + @Override + protected void sendAppointSoftMax(long eventId, float parameter, Matrix featureMatrix, OutBack outBack, String myWord) throws Exception { + boolean allReady = insertParameter(eventId, parameter); + if (allReady) { + Mes mes = softMax(eventId, false);//输出值 + destroyParameter(eventId); + nerveCenter.backType(eventId, mes.poi, mes.typeID, featureMatrix, outBack, myWord); + } + } + + @Override + protected void input(long eventId, float parameter, boolean isStudy, Map E, OutBack outBack + , Matrix rnnMatrix, int[] storeys, int index, int questionLength) throws Exception { + boolean allReady = insertParameter(eventId, parameter); + if (allReady) { + Mes mes = softMax(eventId, isStudy);//输出值 + int key = 0; + if (isStudy) {//学习 + for (Map.Entry entry : E.entrySet()) { + if (entry.getValue() > 0.9) { + key = entry.getKey(); + break; + } + } + if (isShowLog) { + System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID); + } + List errors = error(mes, key); + features.remove(eventId); //清空当前上层输入参数参数 + int size = outNerves.size(); + for (int i = 0; i < size; i++) { + outNerves.get(i).getGBySoftMax(errors.get(i), eventId, storeys, index); + } + } else {//输出 + destroyParameter(eventId); + if (outBack != null) { + outBack.getBack(mes.poi, mes.typeID, eventId); + outBack.getSoftMaxBack(eventId, mes.softMax); + } else { + throw new Exception("not find outBack"); + } + } + } + } + + private List error(Mes mes, int key) { + int t = key - 1; + List softMax = mes.softMax; + List error = new ArrayList<>(); + for (int i = 0; i < softMax.size(); i++) { + float self = softMax.get(i); + float myError; + if (i != t) { + myError = -self; + } else { + myError = 1 - self; + } + error.add(myError); + } + return error; + } + + private Mes softMax(long eventId, boolean isStudy) {//计算当前输出结果 + float sigma = 0; + int id = 0; + float poi = 0; + Mes mes = new Mes(); + List featuresList = features.get(eventId); + for (float value : featuresList) { + sigma = (float) Math.exp(value) + sigma; + } + List softMax = new ArrayList<>(); + for (int i = 0; i < featuresList.size(); i++) { + float eSelf = (float) Math.exp(featuresList.get(i)); + float value = eSelf / sigma; + softMax.add(value); + if (value > poi) { + poi = value; + id = i + 1; + } + } + mes.softMax = softMax; + mes.typeID = id; + mes.poi = poi; + return mes; + } + + static class Mes { + int typeID; + float poi; + List softMax; + } +}