package org.dromara.easyai.resnet; import org.dromara.easyai.conv.ResConvCount; import org.dromara.easyai.i.OutBack; import org.dromara.easyai.matrixTools.Matrix; import org.dromara.easyai.matrixTools.MatrixNorm; import org.dromara.easyai.matrixTools.MatrixOperation; import org.dromara.easyai.nerveEntity.SensoryNerve; import org.dromara.easyai.resnet.entity.ResBlockModel; import org.dromara.easyai.resnet.entity.ResnetError; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Random; /** * @author lidapeng * @time 2025/4/11 13:40 * @des resnet残差模块 */ public class ResBlock extends ResConvCount { private ConvLay firstConvPower;//第一层卷积 private final ResConvPower firstResConvPower = new ResConvPower(); private final ResConvPower secondResConvPower = new ResConvPower(); private final int channelNo; private final int deep; private final float studyRate;//全局学习率 private final int imageSize;//图像大小 private ResBlock fatherResBlock; private ResBlock sonResBlock; private final List sensoryNerves;//输出神经元 private final float gaMa; private final float gMaxTh; private final boolean auto; private final float GRate;//梯度衰减 private final MatrixOperation matrixOperation = new MatrixOperation(); public ResBlockModel getModel() { ResBlockModel model = new ResBlockModel(); model.setFirstResConvModel(firstResConvPower.getModel()); model.setSecondResConvModel(secondResConvPower.getModel()); if (deep == 1) { model.setFirstConvModel(firstConvPower.getModel()); } return model; } public void insertModel(ResBlockModel resBlockModel) { firstResConvPower.insertModel(resBlockModel.getFirstResConvModel()); secondResConvPower.insertModel(resBlockModel.getSecondResConvModel()); if (deep == 1) { firstConvPower.insertModel(resBlockModel.getFirstConvModel()); } } public ResBlock(int channelNo, int deep, float studyRate, int imageSize, List sensoryNerves, float gaMa, float gMaxTh , boolean auto, float GRate) throws Exception { this.imageSize = imageSize; this.GRate = GRate; this.sensoryNerves = sensoryNerves; this.channelNo = channelNo; this.deep = deep; this.gMaxTh = gMaxTh; this.studyRate = studyRate; this.auto = auto; this.gaMa = gaMa; boolean initOneConv = true; Random random = new Random(); if (deep == 1) { initOneConv = false; firstConvPower = initMatrixPower(random, 7, channelNo, true); } initBlock(firstResConvPower, random, initOneConv);//初始化两个残差块 initBlock(secondResConvPower, random, false);//初始化两个残差块 } private void fillZero(List matrixList, boolean fill) throws Exception { int size = matrixList.size(); for (int i = 0; i < size; i++) { Matrix matrix = matrixList.get(i); if (fill) { matrixList.set(i, padding2(matrix, 1)); } else { matrixList.set(i, unPadding2(matrix, 1)); } } } public void backError(List errorMatrixList) throws Exception {//返回误差 List errorList = backOneConvMatrix(errorMatrixList, secondResConvPower, 2); if (deep > 1) { List errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 1); backFatherError(errorFinalMatrix); } else { List errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 2); errorFinalMatrix = backDownPoolingByList(errorFinalMatrix);//退下池化 if (fill(deep, imageSize, false)) {//判断是否需要补0 fillZero(errorFinalMatrix, false); } ResBlockError(errorFinalMatrix, firstConvPower.getBackParameter(), firstConvPower.getMatrixNormList(), firstConvPower.getConvPower(), studyRate, 7, null, firstConvPower.getDymStudyRateList(), gaMa, gMaxTh, auto); } } private void backFatherError(List errorMatrixList) throws Exception {//返回上一层误差 if (fill(deep, imageSize, true)) {//不是偶数,要补0 fillZero(errorMatrixList, false); } fatherResBlock.backError(errorMatrixList); } private List backOneConvMatrix(List errorMatrixList, ResConvPower resConvPower, int deep) throws Exception { ConvLay firstConv = resConvPower.getFirstConvPower(); ConvLay secondConv = resConvPower.getSecondConvPower(); List> oneConvPower = null; List> dymStudyRateList = null; if (deep == 1) { oneConvPower = resConvPower.getOneConvPower(); dymStudyRateList = resConvPower.getDymStudyRateList(); } ResnetError resnetError = ResBlockError2(errorMatrixList, secondConv.getBackParameter(), secondConv.getMatrixNormList(), secondConv.getConvPower(), studyRate, true, oneConvPower, null, secondConv.getDymStudyRateList() , dymStudyRateList, gaMa, gMaxTh, auto); List resErrorMatrixList = resnetError.getResErrorMatrixList();//残差误差 List nextErrorMatrixList = resnetError.getNextErrorMatrixList();//下一层误差 List errorList; if (deep == 2) { errorList = ResBlockError2(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(), firstConv.getConvPower(), studyRate, false, oneConvPower, resErrorMatrixList, firstConv.getDymStudyRateList(), dymStudyRateList, gaMa, gMaxTh, auto).getNextErrorMatrixList(); } else { errorList = ResBlockError(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(), firstConv.getConvPower(), studyRate, 3, resErrorMatrixList, firstConv.getDymStudyRateList(), gaMa, gMaxTh , auto); } matrixOperation.mathMulByList(errorList, GRate); return errorList; } private void convMatrix(List feature, int step, boolean study, OutBack outBack, Map E, long eventID, boolean outFeature) throws Exception {// feature 准备跳层用 boolean one = step == 1; List featureList = oneConvMatrix(feature, firstResConvPower, study, one); List lastFeatureList = oneConvMatrix(featureList, secondResConvPower, study, true); if (sonResBlock != null) { sonResBlock.sendMatrixList(lastFeatureList, outBack, study, E, eventID, outFeature); } else {//最后卷积层了,求平均值 List outFeatures = new ArrayList<>(); for (Matrix matrix : lastFeatureList) { outFeatures.add(matrix.getAVG()); } if (!study && outFeature) { int size = outFeatures.size(); Matrix matrix = new Matrix(1, size); for (int j = 0; j < size; j++) { matrix.setNub(0, j, outFeatures.get(j)); } if (outBack != null) { outBack.getBackMatrix(matrix, 1, eventID); } else { throw new Exception("没有传入OutBack输出回调类"); } } else { if (sensoryNerves.size() == outFeatures.size()) { int size = sensoryNerves.size(); for (int i = 0; i < size; i++) { sensoryNerves.get(i).postMessage(eventID, outFeatures.get(i), study, E, outBack); } } else { throw new Exception("线性层与特征层特征维度不相等"); } } } } private List oneConvMatrix(List feature, ResConvPower resConvPower, boolean study, boolean one) throws Exception { ConvLay firstConvLay = resConvPower.getFirstConvPower(); ConvLay secondConvLay = resConvPower.getSecondConvPower(); List> oneConvPower = resConvPower.getOneConvPower(); List firstOutMatrix; if (one) {//步长为1 firstOutMatrix = downConvMany2(feature, firstConvLay.getConvPower(), study, firstConvLay.getBackParameter(), firstConvLay.getMatrixNormList(), null, null); } else {//步长为2 firstOutMatrix = downConvMany(feature, firstConvLay.getConvPower(), 3, study, firstConvLay.getBackParameter(), firstConvLay.getMatrixNormList()); } return downConvMany2(firstOutMatrix, secondConvLay.getConvPower(), study, secondConvLay.getBackParameter() , secondConvLay.getMatrixNormList(), feature, oneConvPower); } public void sendMatrixList(List matrixList, OutBack outBack, boolean study, Map E, long eventID, boolean outFeature) throws Exception { //判定特征大小是否为偶数 if (fill(deep, imageSize, true)) {//不是偶数,要补0 fillZero(matrixList, true); } if (deep == 1) { List myMatrixList = downConvMany(matrixList, firstConvPower.getConvPower(), 7, study, firstConvPower.getBackParameter(), firstConvPower.getMatrixNormList()); if (fill(deep, imageSize, false)) {//池化需要补0 fillZero(myMatrixList, true); } //池化 List nextMatrixList = new ArrayList<>(); for (Matrix matrix : myMatrixList) { nextMatrixList.add(downPooling(matrix)); } convMatrix(nextMatrixList, 1, study, outBack, E, eventID, outFeature); } else { convMatrix(matrixList, 2, study, outBack, E, eventID, outFeature); } } private int getChannelNo() { return (int) (channelNo * Math.pow(2, deep - 1));//卷积层输出特征大小 } private void initBlock(ResConvPower resConvPower, Random random, boolean initOneConv) throws Exception { resConvPower.setFirstConvPower(initMatrixPower(random, 3, getChannelNo(), false)); resConvPower.setSecondConvPower(initMatrixPower(random, 3, getChannelNo(), false)); if (deep > 1 && initOneConv) {//初始化11卷积层 int featureLength = getChannelNo();//卷积层输出特征大小 List> onePowers = new ArrayList<>(); List> dymStudyRateList = new ArrayList<>(); resConvPower.setOneConvPower(onePowers); resConvPower.setDymStudyRateList(dymStudyRateList); int length = featureLength / 2; for (int i = 0; i < featureLength; i++) { List oneConvPowerList = new ArrayList<>(); List dymStudyRage = new ArrayList<>(); for (int j = 0; j < length; j++) { oneConvPowerList.add(random.nextFloat() / length); dymStudyRage.add(0f); } dymStudyRateList.add(dymStudyRage); onePowers.add(oneConvPowerList); } } } private ConvLay initMatrixPower(Random random, int kernLen, int channelNo, boolean seven) throws Exception { int nerveNub = kernLen * kernLen; ConvLay convLay = new ConvLay(); List nerveMatrixList = new ArrayList<>();//一层当中所有的深度卷积核 List sumOfSquares = new ArrayList<>();//动态学习率 List matrixNormList = new ArrayList<>(); int size = getFeatureSize(deep, imageSize, seven); for (int k = 0; k < channelNo; k++) {//遍历通道 Matrix nerveMatrix = new Matrix(nerveNub, 1);//一组通道创建一组卷积核 Matrix dymStudyRate = new Matrix(nerveNub, 1); sumOfSquares.add(dymStudyRate); for (int i = 0; i < nerveMatrix.getX(); i++) {//初始化深度卷积核权重 float nub = random.nextFloat() / kernLen; nerveMatrix.setNub(i, 0, nub); } nerveMatrixList.add(nerveMatrix); MatrixNorm matrixNorm = new MatrixNorm(size, studyRate, gaMa, gMaxTh, auto); matrixNormList.add(matrixNorm); } convLay.setDymStudyRateList(sumOfSquares); convLay.setConvPower(nerveMatrixList); convLay.setMatrixNormList(matrixNormList); return convLay; } public void setSonResBlock(ResBlock sonResBlock) { this.sonResBlock = sonResBlock; } public void setFatherResBlock(ResBlock fatherResBlock) { this.fatherResBlock = fatherResBlock; } }