Add File
This commit is contained in:
275
src/main/java/org/dromara/easyai/resnet/ResBlock.java
Normal file
275
src/main/java/org/dromara/easyai/resnet/ResBlock.java
Normal file
@@ -0,0 +1,275 @@
|
||||
package org.dromara.easyai.resnet;
|
||||
|
||||
import org.dromara.easyai.conv.ResConvCount;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixNorm;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.nerveEntity.SensoryNerve;
|
||||
import org.dromara.easyai.resnet.entity.ResBlockModel;
|
||||
import org.dromara.easyai.resnet.entity.ResnetError;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @time 2025/4/11 13:40
|
||||
* @des resnet残差模块
|
||||
*/
|
||||
public class ResBlock extends ResConvCount {
|
||||
private ConvLay firstConvPower;//第一层卷积
|
||||
private final ResConvPower firstResConvPower = new ResConvPower();
|
||||
private final ResConvPower secondResConvPower = new ResConvPower();
|
||||
private final int channelNo;
|
||||
private final int deep;
|
||||
private final float studyRate;//全局学习率
|
||||
private final int imageSize;//图像大小
|
||||
private ResBlock fatherResBlock;
|
||||
private ResBlock sonResBlock;
|
||||
private final List<SensoryNerve> sensoryNerves;//输出神经元
|
||||
private final float gaMa;
|
||||
private final float gMaxTh;
|
||||
private final boolean auto;
|
||||
private final float GRate;//梯度衰减
|
||||
private final MatrixOperation matrixOperation = new MatrixOperation();
|
||||
|
||||
public ResBlockModel getModel() {
|
||||
ResBlockModel model = new ResBlockModel();
|
||||
model.setFirstResConvModel(firstResConvPower.getModel());
|
||||
model.setSecondResConvModel(secondResConvPower.getModel());
|
||||
if (deep == 1) {
|
||||
model.setFirstConvModel(firstConvPower.getModel());
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
public void insertModel(ResBlockModel resBlockModel) {
|
||||
firstResConvPower.insertModel(resBlockModel.getFirstResConvModel());
|
||||
secondResConvPower.insertModel(resBlockModel.getSecondResConvModel());
|
||||
if (deep == 1) {
|
||||
firstConvPower.insertModel(resBlockModel.getFirstConvModel());
|
||||
}
|
||||
}
|
||||
|
||||
public ResBlock(int channelNo, int deep, float studyRate, int imageSize, List<SensoryNerve> sensoryNerves, float gaMa, float gMaxTh
|
||||
, boolean auto, float GRate) throws Exception {
|
||||
this.imageSize = imageSize;
|
||||
this.GRate = GRate;
|
||||
this.sensoryNerves = sensoryNerves;
|
||||
this.channelNo = channelNo;
|
||||
this.deep = deep;
|
||||
this.gMaxTh = gMaxTh;
|
||||
this.studyRate = studyRate;
|
||||
this.auto = auto;
|
||||
this.gaMa = gaMa;
|
||||
boolean initOneConv = true;
|
||||
Random random = new Random();
|
||||
if (deep == 1) {
|
||||
initOneConv = false;
|
||||
firstConvPower = initMatrixPower(random, 7, channelNo, true);
|
||||
}
|
||||
initBlock(firstResConvPower, random, initOneConv);//初始化两个残差块
|
||||
initBlock(secondResConvPower, random, false);//初始化两个残差块
|
||||
}
|
||||
|
||||
private void fillZero(List<Matrix> matrixList, boolean fill) throws Exception {
|
||||
int size = matrixList.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
Matrix matrix = matrixList.get(i);
|
||||
if (fill) {
|
||||
matrixList.set(i, padding2(matrix, 1));
|
||||
} else {
|
||||
matrixList.set(i, unPadding2(matrix, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void backError(List<Matrix> errorMatrixList) throws Exception {//返回误差
|
||||
List<Matrix> errorList = backOneConvMatrix(errorMatrixList, secondResConvPower, 2);
|
||||
if (deep > 1) {
|
||||
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 1);
|
||||
backFatherError(errorFinalMatrix);
|
||||
} else {
|
||||
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 2);
|
||||
errorFinalMatrix = backDownPoolingByList(errorFinalMatrix);//退下池化
|
||||
if (fill(deep, imageSize, false)) {//判断是否需要补0
|
||||
fillZero(errorFinalMatrix, false);
|
||||
}
|
||||
ResBlockError(errorFinalMatrix, firstConvPower.getBackParameter(), firstConvPower.getMatrixNormList(),
|
||||
firstConvPower.getConvPower(), studyRate, 7, null, firstConvPower.getDymStudyRateList(),
|
||||
gaMa, gMaxTh, auto);
|
||||
}
|
||||
}
|
||||
|
||||
private void backFatherError(List<Matrix> errorMatrixList) throws Exception {//返回上一层误差
|
||||
if (fill(deep, imageSize, true)) {//不是偶数,要补0
|
||||
fillZero(errorMatrixList, false);
|
||||
}
|
||||
fatherResBlock.backError(errorMatrixList);
|
||||
}
|
||||
|
||||
private List<Matrix> backOneConvMatrix(List<Matrix> errorMatrixList, ResConvPower resConvPower, int deep) throws Exception {
|
||||
ConvLay firstConv = resConvPower.getFirstConvPower();
|
||||
ConvLay secondConv = resConvPower.getSecondConvPower();
|
||||
List<List<Float>> oneConvPower = null;
|
||||
List<List<Float>> dymStudyRateList = null;
|
||||
if (deep == 1) {
|
||||
oneConvPower = resConvPower.getOneConvPower();
|
||||
dymStudyRateList = resConvPower.getDymStudyRateList();
|
||||
}
|
||||
ResnetError resnetError = ResBlockError2(errorMatrixList, secondConv.getBackParameter(), secondConv.getMatrixNormList(),
|
||||
secondConv.getConvPower(), studyRate, true, oneConvPower, null, secondConv.getDymStudyRateList()
|
||||
, dymStudyRateList, gaMa, gMaxTh, auto);
|
||||
List<Matrix> resErrorMatrixList = resnetError.getResErrorMatrixList();//残差误差
|
||||
List<Matrix> nextErrorMatrixList = resnetError.getNextErrorMatrixList();//下一层误差
|
||||
List<Matrix> errorList;
|
||||
if (deep == 2) {
|
||||
errorList = ResBlockError2(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
|
||||
firstConv.getConvPower(), studyRate, false, oneConvPower, resErrorMatrixList,
|
||||
firstConv.getDymStudyRateList(), dymStudyRateList, gaMa, gMaxTh, auto).getNextErrorMatrixList();
|
||||
} else {
|
||||
errorList = ResBlockError(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
|
||||
firstConv.getConvPower(), studyRate, 3, resErrorMatrixList, firstConv.getDymStudyRateList(), gaMa, gMaxTh
|
||||
, auto);
|
||||
}
|
||||
matrixOperation.mathMulByList(errorList, GRate);
|
||||
return errorList;
|
||||
}
|
||||
|
||||
private void convMatrix(List<Matrix> feature, int step, boolean study, OutBack outBack, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {// feature 准备跳层用
|
||||
boolean one = step == 1;
|
||||
List<Matrix> featureList = oneConvMatrix(feature, firstResConvPower, study, one);
|
||||
List<Matrix> lastFeatureList = oneConvMatrix(featureList, secondResConvPower, study, true);
|
||||
if (sonResBlock != null) {
|
||||
sonResBlock.sendMatrixList(lastFeatureList, outBack, study, E, eventID, outFeature);
|
||||
} else {//最后卷积层了,求平均值
|
||||
List<Float> outFeatures = new ArrayList<>();
|
||||
for (Matrix matrix : lastFeatureList) {
|
||||
outFeatures.add(matrix.getAVG());
|
||||
}
|
||||
if (!study && outFeature) {
|
||||
int size = outFeatures.size();
|
||||
Matrix matrix = new Matrix(1, size);
|
||||
for (int j = 0; j < size; j++) {
|
||||
matrix.setNub(0, j, outFeatures.get(j));
|
||||
}
|
||||
if (outBack != null) {
|
||||
outBack.getBackMatrix(matrix, 1, eventID);
|
||||
} else {
|
||||
throw new Exception("没有传入OutBack输出回调类");
|
||||
}
|
||||
} else {
|
||||
if (sensoryNerves.size() == outFeatures.size()) {
|
||||
int size = sensoryNerves.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
sensoryNerves.get(i).postMessage(eventID, outFeatures.get(i), study, E, outBack);
|
||||
}
|
||||
} else {
|
||||
throw new Exception("线性层与特征层特征维度不相等");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<Matrix> oneConvMatrix(List<Matrix> feature, ResConvPower resConvPower, boolean study, boolean one) throws Exception {
|
||||
ConvLay firstConvLay = resConvPower.getFirstConvPower();
|
||||
ConvLay secondConvLay = resConvPower.getSecondConvPower();
|
||||
List<List<Float>> oneConvPower = resConvPower.getOneConvPower();
|
||||
List<Matrix> firstOutMatrix;
|
||||
if (one) {//步长为1
|
||||
firstOutMatrix = downConvMany2(feature, firstConvLay.getConvPower(), study, firstConvLay.getBackParameter(),
|
||||
firstConvLay.getMatrixNormList(), null, null);
|
||||
} else {//步长为2
|
||||
firstOutMatrix = downConvMany(feature, firstConvLay.getConvPower(), 3, study, firstConvLay.getBackParameter(),
|
||||
firstConvLay.getMatrixNormList());
|
||||
}
|
||||
return downConvMany2(firstOutMatrix, secondConvLay.getConvPower(), study, secondConvLay.getBackParameter()
|
||||
, secondConvLay.getMatrixNormList(), feature, oneConvPower);
|
||||
}
|
||||
|
||||
public void sendMatrixList(List<Matrix> matrixList, OutBack outBack, boolean study, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {
|
||||
//判定特征大小是否为偶数
|
||||
if (fill(deep, imageSize, true)) {//不是偶数,要补0
|
||||
fillZero(matrixList, true);
|
||||
}
|
||||
if (deep == 1) {
|
||||
List<Matrix> myMatrixList = downConvMany(matrixList, firstConvPower.getConvPower(), 7, study, firstConvPower.getBackParameter(),
|
||||
firstConvPower.getMatrixNormList());
|
||||
if (fill(deep, imageSize, false)) {//池化需要补0
|
||||
fillZero(myMatrixList, true);
|
||||
}
|
||||
//池化
|
||||
List<Matrix> nextMatrixList = new ArrayList<>();
|
||||
for (Matrix matrix : myMatrixList) {
|
||||
nextMatrixList.add(downPooling(matrix));
|
||||
}
|
||||
convMatrix(nextMatrixList, 1, study, outBack, E, eventID, outFeature);
|
||||
} else {
|
||||
convMatrix(matrixList, 2, study, outBack, E, eventID, outFeature);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private int getChannelNo() {
|
||||
return (int) (channelNo * Math.pow(2, deep - 1));//卷积层输出特征大小
|
||||
}
|
||||
|
||||
private void initBlock(ResConvPower resConvPower, Random random, boolean initOneConv) throws Exception {
|
||||
resConvPower.setFirstConvPower(initMatrixPower(random, 3, getChannelNo(), false));
|
||||
resConvPower.setSecondConvPower(initMatrixPower(random, 3, getChannelNo(), false));
|
||||
if (deep > 1 && initOneConv) {//初始化11卷积层
|
||||
int featureLength = getChannelNo();//卷积层输出特征大小
|
||||
List<List<Float>> onePowers = new ArrayList<>();
|
||||
List<List<Float>> dymStudyRateList = new ArrayList<>();
|
||||
resConvPower.setOneConvPower(onePowers);
|
||||
resConvPower.setDymStudyRateList(dymStudyRateList);
|
||||
int length = featureLength / 2;
|
||||
for (int i = 0; i < featureLength; i++) {
|
||||
List<Float> oneConvPowerList = new ArrayList<>();
|
||||
List<Float> dymStudyRage = new ArrayList<>();
|
||||
for (int j = 0; j < length; j++) {
|
||||
oneConvPowerList.add(random.nextFloat() / length);
|
||||
dymStudyRage.add(0f);
|
||||
}
|
||||
dymStudyRateList.add(dymStudyRage);
|
||||
onePowers.add(oneConvPowerList);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private ConvLay initMatrixPower(Random random, int kernLen, int channelNo, boolean seven) throws Exception {
|
||||
int nerveNub = kernLen * kernLen;
|
||||
ConvLay convLay = new ConvLay();
|
||||
List<Matrix> nerveMatrixList = new ArrayList<>();//一层当中所有的深度卷积核
|
||||
List<Matrix> sumOfSquares = new ArrayList<>();//动态学习率
|
||||
List<MatrixNorm> matrixNormList = new ArrayList<>();
|
||||
int size = getFeatureSize(deep, imageSize, seven);
|
||||
for (int k = 0; k < channelNo; k++) {//遍历通道
|
||||
Matrix nerveMatrix = new Matrix(nerveNub, 1);//一组通道创建一组卷积核
|
||||
Matrix dymStudyRate = new Matrix(nerveNub, 1);
|
||||
sumOfSquares.add(dymStudyRate);
|
||||
for (int i = 0; i < nerveMatrix.getX(); i++) {//初始化深度卷积核权重
|
||||
float nub = random.nextFloat() / kernLen;
|
||||
nerveMatrix.setNub(i, 0, nub);
|
||||
}
|
||||
nerveMatrixList.add(nerveMatrix);
|
||||
MatrixNorm matrixNorm = new MatrixNorm(size, studyRate, gaMa, gMaxTh, auto);
|
||||
matrixNormList.add(matrixNorm);
|
||||
}
|
||||
convLay.setDymStudyRateList(sumOfSquares);
|
||||
convLay.setConvPower(nerveMatrixList);
|
||||
convLay.setMatrixNormList(matrixNormList);
|
||||
return convLay;
|
||||
}
|
||||
|
||||
public void setSonResBlock(ResBlock sonResBlock) {
|
||||
this.sonResBlock = sonResBlock;
|
||||
}
|
||||
|
||||
public void setFatherResBlock(ResBlock fatherResBlock) {
|
||||
this.fatherResBlock = fatherResBlock;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user