This commit is contained in:
2025-09-04 14:09:33 +08:00
parent bc48177e3a
commit b293824bf0

View File

@@ -0,0 +1,275 @@
package org.dromara.easyai.resnet;
import org.dromara.easyai.conv.ResConvCount;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixNorm;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.SensoryNerve;
import org.dromara.easyai.resnet.entity.ResBlockModel;
import org.dromara.easyai.resnet.entity.ResnetError;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* @author lidapeng
* @time 2025/4/11 13:40
* @des resnet残差模块
*/
public class ResBlock extends ResConvCount {
private ConvLay firstConvPower;//第一层卷积
private final ResConvPower firstResConvPower = new ResConvPower();
private final ResConvPower secondResConvPower = new ResConvPower();
private final int channelNo;
private final int deep;
private final float studyRate;//全局学习率
private final int imageSize;//图像大小
private ResBlock fatherResBlock;
private ResBlock sonResBlock;
private final List<SensoryNerve> sensoryNerves;//输出神经元
private final float gaMa;
private final float gMaxTh;
private final boolean auto;
private final float GRate;//梯度衰减
private final MatrixOperation matrixOperation = new MatrixOperation();
public ResBlockModel getModel() {
ResBlockModel model = new ResBlockModel();
model.setFirstResConvModel(firstResConvPower.getModel());
model.setSecondResConvModel(secondResConvPower.getModel());
if (deep == 1) {
model.setFirstConvModel(firstConvPower.getModel());
}
return model;
}
public void insertModel(ResBlockModel resBlockModel) {
firstResConvPower.insertModel(resBlockModel.getFirstResConvModel());
secondResConvPower.insertModel(resBlockModel.getSecondResConvModel());
if (deep == 1) {
firstConvPower.insertModel(resBlockModel.getFirstConvModel());
}
}
public ResBlock(int channelNo, int deep, float studyRate, int imageSize, List<SensoryNerve> sensoryNerves, float gaMa, float gMaxTh
, boolean auto, float GRate) throws Exception {
this.imageSize = imageSize;
this.GRate = GRate;
this.sensoryNerves = sensoryNerves;
this.channelNo = channelNo;
this.deep = deep;
this.gMaxTh = gMaxTh;
this.studyRate = studyRate;
this.auto = auto;
this.gaMa = gaMa;
boolean initOneConv = true;
Random random = new Random();
if (deep == 1) {
initOneConv = false;
firstConvPower = initMatrixPower(random, 7, channelNo, true);
}
initBlock(firstResConvPower, random, initOneConv);//初始化两个残差块
initBlock(secondResConvPower, random, false);//初始化两个残差块
}
private void fillZero(List<Matrix> matrixList, boolean fill) throws Exception {
int size = matrixList.size();
for (int i = 0; i < size; i++) {
Matrix matrix = matrixList.get(i);
if (fill) {
matrixList.set(i, padding2(matrix, 1));
} else {
matrixList.set(i, unPadding2(matrix, 1));
}
}
}
public void backError(List<Matrix> errorMatrixList) throws Exception {//返回误差
List<Matrix> errorList = backOneConvMatrix(errorMatrixList, secondResConvPower, 2);
if (deep > 1) {
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 1);
backFatherError(errorFinalMatrix);
} else {
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 2);
errorFinalMatrix = backDownPoolingByList(errorFinalMatrix);//退下池化
if (fill(deep, imageSize, false)) {//判断是否需要补0
fillZero(errorFinalMatrix, false);
}
ResBlockError(errorFinalMatrix, firstConvPower.getBackParameter(), firstConvPower.getMatrixNormList(),
firstConvPower.getConvPower(), studyRate, 7, null, firstConvPower.getDymStudyRateList(),
gaMa, gMaxTh, auto);
}
}
private void backFatherError(List<Matrix> errorMatrixList) throws Exception {//返回上一层误差
if (fill(deep, imageSize, true)) {//不是偶数,要补0
fillZero(errorMatrixList, false);
}
fatherResBlock.backError(errorMatrixList);
}
private List<Matrix> backOneConvMatrix(List<Matrix> errorMatrixList, ResConvPower resConvPower, int deep) throws Exception {
ConvLay firstConv = resConvPower.getFirstConvPower();
ConvLay secondConv = resConvPower.getSecondConvPower();
List<List<Float>> oneConvPower = null;
List<List<Float>> dymStudyRateList = null;
if (deep == 1) {
oneConvPower = resConvPower.getOneConvPower();
dymStudyRateList = resConvPower.getDymStudyRateList();
}
ResnetError resnetError = ResBlockError2(errorMatrixList, secondConv.getBackParameter(), secondConv.getMatrixNormList(),
secondConv.getConvPower(), studyRate, true, oneConvPower, null, secondConv.getDymStudyRateList()
, dymStudyRateList, gaMa, gMaxTh, auto);
List<Matrix> resErrorMatrixList = resnetError.getResErrorMatrixList();//残差误差
List<Matrix> nextErrorMatrixList = resnetError.getNextErrorMatrixList();//下一层误差
List<Matrix> errorList;
if (deep == 2) {
errorList = ResBlockError2(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
firstConv.getConvPower(), studyRate, false, oneConvPower, resErrorMatrixList,
firstConv.getDymStudyRateList(), dymStudyRateList, gaMa, gMaxTh, auto).getNextErrorMatrixList();
} else {
errorList = ResBlockError(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
firstConv.getConvPower(), studyRate, 3, resErrorMatrixList, firstConv.getDymStudyRateList(), gaMa, gMaxTh
, auto);
}
matrixOperation.mathMulByList(errorList, GRate);
return errorList;
}
private void convMatrix(List<Matrix> feature, int step, boolean study, OutBack outBack, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {// feature 准备跳层用
boolean one = step == 1;
List<Matrix> featureList = oneConvMatrix(feature, firstResConvPower, study, one);
List<Matrix> lastFeatureList = oneConvMatrix(featureList, secondResConvPower, study, true);
if (sonResBlock != null) {
sonResBlock.sendMatrixList(lastFeatureList, outBack, study, E, eventID, outFeature);
} else {//最后卷积层了,求平均值
List<Float> outFeatures = new ArrayList<>();
for (Matrix matrix : lastFeatureList) {
outFeatures.add(matrix.getAVG());
}
if (!study && outFeature) {
int size = outFeatures.size();
Matrix matrix = new Matrix(1, size);
for (int j = 0; j < size; j++) {
matrix.setNub(0, j, outFeatures.get(j));
}
if (outBack != null) {
outBack.getBackMatrix(matrix, 1, eventID);
} else {
throw new Exception("没有传入OutBack输出回调类");
}
} else {
if (sensoryNerves.size() == outFeatures.size()) {
int size = sensoryNerves.size();
for (int i = 0; i < size; i++) {
sensoryNerves.get(i).postMessage(eventID, outFeatures.get(i), study, E, outBack);
}
} else {
throw new Exception("线性层与特征层特征维度不相等");
}
}
}
}
private List<Matrix> oneConvMatrix(List<Matrix> feature, ResConvPower resConvPower, boolean study, boolean one) throws Exception {
ConvLay firstConvLay = resConvPower.getFirstConvPower();
ConvLay secondConvLay = resConvPower.getSecondConvPower();
List<List<Float>> oneConvPower = resConvPower.getOneConvPower();
List<Matrix> firstOutMatrix;
if (one) {//步长为1
firstOutMatrix = downConvMany2(feature, firstConvLay.getConvPower(), study, firstConvLay.getBackParameter(),
firstConvLay.getMatrixNormList(), null, null);
} else {//步长为2
firstOutMatrix = downConvMany(feature, firstConvLay.getConvPower(), 3, study, firstConvLay.getBackParameter(),
firstConvLay.getMatrixNormList());
}
return downConvMany2(firstOutMatrix, secondConvLay.getConvPower(), study, secondConvLay.getBackParameter()
, secondConvLay.getMatrixNormList(), feature, oneConvPower);
}
public void sendMatrixList(List<Matrix> matrixList, OutBack outBack, boolean study, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {
//判定特征大小是否为偶数
if (fill(deep, imageSize, true)) {//不是偶数,要补0
fillZero(matrixList, true);
}
if (deep == 1) {
List<Matrix> myMatrixList = downConvMany(matrixList, firstConvPower.getConvPower(), 7, study, firstConvPower.getBackParameter(),
firstConvPower.getMatrixNormList());
if (fill(deep, imageSize, false)) {//池化需要补0
fillZero(myMatrixList, true);
}
//池化
List<Matrix> nextMatrixList = new ArrayList<>();
for (Matrix matrix : myMatrixList) {
nextMatrixList.add(downPooling(matrix));
}
convMatrix(nextMatrixList, 1, study, outBack, E, eventID, outFeature);
} else {
convMatrix(matrixList, 2, study, outBack, E, eventID, outFeature);
}
}
private int getChannelNo() {
return (int) (channelNo * Math.pow(2, deep - 1));//卷积层输出特征大小
}
private void initBlock(ResConvPower resConvPower, Random random, boolean initOneConv) throws Exception {
resConvPower.setFirstConvPower(initMatrixPower(random, 3, getChannelNo(), false));
resConvPower.setSecondConvPower(initMatrixPower(random, 3, getChannelNo(), false));
if (deep > 1 && initOneConv) {//初始化11卷积层
int featureLength = getChannelNo();//卷积层输出特征大小
List<List<Float>> onePowers = new ArrayList<>();
List<List<Float>> dymStudyRateList = new ArrayList<>();
resConvPower.setOneConvPower(onePowers);
resConvPower.setDymStudyRateList(dymStudyRateList);
int length = featureLength / 2;
for (int i = 0; i < featureLength; i++) {
List<Float> oneConvPowerList = new ArrayList<>();
List<Float> dymStudyRage = new ArrayList<>();
for (int j = 0; j < length; j++) {
oneConvPowerList.add(random.nextFloat() / length);
dymStudyRage.add(0f);
}
dymStudyRateList.add(dymStudyRage);
onePowers.add(oneConvPowerList);
}
}
}
private ConvLay initMatrixPower(Random random, int kernLen, int channelNo, boolean seven) throws Exception {
int nerveNub = kernLen * kernLen;
ConvLay convLay = new ConvLay();
List<Matrix> nerveMatrixList = new ArrayList<>();//一层当中所有的深度卷积核
List<Matrix> sumOfSquares = new ArrayList<>();//动态学习率
List<MatrixNorm> matrixNormList = new ArrayList<>();
int size = getFeatureSize(deep, imageSize, seven);
for (int k = 0; k < channelNo; k++) {//遍历通道
Matrix nerveMatrix = new Matrix(nerveNub, 1);//一组通道创建一组卷积核
Matrix dymStudyRate = new Matrix(nerveNub, 1);
sumOfSquares.add(dymStudyRate);
for (int i = 0; i < nerveMatrix.getX(); i++) {//初始化深度卷积核权重
float nub = random.nextFloat() / kernLen;
nerveMatrix.setNub(i, 0, nub);
}
nerveMatrixList.add(nerveMatrix);
MatrixNorm matrixNorm = new MatrixNorm(size, studyRate, gaMa, gMaxTh, auto);
matrixNormList.add(matrixNorm);
}
convLay.setDymStudyRateList(sumOfSquares);
convLay.setConvPower(nerveMatrixList);
convLay.setMatrixNormList(matrixNormList);
return convLay;
}
public void setSonResBlock(ResBlock sonResBlock) {
this.sonResBlock = sonResBlock;
}
public void setFatherResBlock(ResBlock fatherResBlock) {
this.fatherResBlock = fatherResBlock;
}
}