diff --git a/src/main/java/org/dromara/easyai/resnet/ResNetConnectionLine.java b/src/main/java/org/dromara/easyai/resnet/ResNetConnectionLine.java new file mode 100644 index 0000000..7227711 --- /dev/null +++ b/src/main/java/org/dromara/easyai/resnet/ResNetConnectionLine.java @@ -0,0 +1,76 @@ +package org.dromara.easyai.resnet; + +import org.dromara.easyai.i.CustomEncoding; +import org.dromara.easyai.matrixTools.Matrix; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * @author lidapeng + * @time 2025/4/11 11:17 + */ +public class ResNetConnectionLine implements CustomEncoding { + private ResBlock lastBlock;//最后一层残差块 + private int lastSize;//最后一层的特征大小 + private int allTimes;//隐层神经数量 + private int nerveSize;//输入神经元数量 + private int number = 0; + private final List errorValues = new ArrayList<>(); + + public void setLastBlock(ResBlock lastBlock, int lastSize, int allTimes, int nerveSize) { + this.lastBlock = lastBlock; + this.lastSize = lastSize; + this.allTimes = allTimes; + this.nerveSize = nerveSize; + } + + private void addError(Map wg) throws Exception { + if (wg.size() != nerveSize) { + throw new Exception("线性层回传误差数量与预设值不相等"); + } + for (int i = 1; i <= nerveSize; i++) { + float error = wg.get(i); + if (number == 1) { + errorValues.add(error); + } else { + float value = error + errorValues.get(i - 1); + errorValues.set(i - 1, value); + } + } + } + + private void fill(Matrix feature, float value) throws Exception { + int x = feature.getX(); + int y = feature.getY(); + float myValue = value / (x * y); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + feature.setNub(i, j, myValue); + } + } + } + + private void toError() throws Exception { + List errorMatrix = new ArrayList<>(); + for (Float errorValue : errorValues) { + Matrix feature = new Matrix(lastSize, lastSize); + float error = errorValue; + fill(feature, error); + errorMatrix.add(feature); + } + errorValues.clear(); + lastBlock.backError(errorMatrix); + } + + @Override + public void backError(Map wg, int id) throws Exception { + number++; + addError(wg); + if (number == allTimes) { + number = 0; + toError(); + } + } +}