276 lines
13 KiB
Java
276 lines
13 KiB
Java
package org.dromara.easyai.resnet;
|
|
|
|
import org.dromara.easyai.conv.ResConvCount;
|
|
import org.dromara.easyai.i.OutBack;
|
|
import org.dromara.easyai.matrixTools.Matrix;
|
|
import org.dromara.easyai.matrixTools.MatrixNorm;
|
|
import org.dromara.easyai.matrixTools.MatrixOperation;
|
|
import org.dromara.easyai.nerveEntity.SensoryNerve;
|
|
import org.dromara.easyai.resnet.entity.ResBlockModel;
|
|
import org.dromara.easyai.resnet.entity.ResnetError;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Random;
|
|
|
|
/**
|
|
* @author lidapeng
|
|
* @time 2025/4/11 13:40
|
|
* @des ResNet residual block (resnet残差模块)
|
|
*/
|
|
public class ResBlock extends ResConvCount {
|
|
private ConvLay firstConvPower;//第一层卷积
|
|
private final ResConvPower firstResConvPower = new ResConvPower();
|
|
private final ResConvPower secondResConvPower = new ResConvPower();
|
|
private final int channelNo;
|
|
private final int deep;
|
|
private final float studyRate;//全局学习率
|
|
private final int imageSize;//图像大小
|
|
private ResBlock fatherResBlock;
|
|
private ResBlock sonResBlock;
|
|
private final List<SensoryNerve> sensoryNerves;//输出神经元
|
|
private final float gaMa;
|
|
private final float gMaxTh;
|
|
private final boolean auto;
|
|
private final float GRate;//梯度衰减
|
|
private final MatrixOperation matrixOperation = new MatrixOperation();
|
|
|
|
public ResBlockModel getModel() {
|
|
ResBlockModel model = new ResBlockModel();
|
|
model.setFirstResConvModel(firstResConvPower.getModel());
|
|
model.setSecondResConvModel(secondResConvPower.getModel());
|
|
if (deep == 1) {
|
|
model.setFirstConvModel(firstConvPower.getModel());
|
|
}
|
|
return model;
|
|
}
|
|
|
|
public void insertModel(ResBlockModel resBlockModel) {
|
|
firstResConvPower.insertModel(resBlockModel.getFirstResConvModel());
|
|
secondResConvPower.insertModel(resBlockModel.getSecondResConvModel());
|
|
if (deep == 1) {
|
|
firstConvPower.insertModel(resBlockModel.getFirstConvModel());
|
|
}
|
|
}
|
|
|
|
public ResBlock(int channelNo, int deep, float studyRate, int imageSize, List<SensoryNerve> sensoryNerves, float gaMa, float gMaxTh
|
|
, boolean auto, float GRate) throws Exception {
|
|
this.imageSize = imageSize;
|
|
this.GRate = GRate;
|
|
this.sensoryNerves = sensoryNerves;
|
|
this.channelNo = channelNo;
|
|
this.deep = deep;
|
|
this.gMaxTh = gMaxTh;
|
|
this.studyRate = studyRate;
|
|
this.auto = auto;
|
|
this.gaMa = gaMa;
|
|
boolean initOneConv = true;
|
|
Random random = new Random();
|
|
if (deep == 1) {
|
|
initOneConv = false;
|
|
firstConvPower = initMatrixPower(random, 7, channelNo, true);
|
|
}
|
|
initBlock(firstResConvPower, random, initOneConv);//初始化两个残差块
|
|
initBlock(secondResConvPower, random, false);//初始化两个残差块
|
|
}
|
|
|
|
private void fillZero(List<Matrix> matrixList, boolean fill) throws Exception {
|
|
int size = matrixList.size();
|
|
for (int i = 0; i < size; i++) {
|
|
Matrix matrix = matrixList.get(i);
|
|
if (fill) {
|
|
matrixList.set(i, padding2(matrix, 1));
|
|
} else {
|
|
matrixList.set(i, unPadding2(matrix, 1));
|
|
}
|
|
}
|
|
}
|
|
|
|
public void backError(List<Matrix> errorMatrixList) throws Exception {//返回误差
|
|
List<Matrix> errorList = backOneConvMatrix(errorMatrixList, secondResConvPower, 2);
|
|
if (deep > 1) {
|
|
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 1);
|
|
backFatherError(errorFinalMatrix);
|
|
} else {
|
|
List<Matrix> errorFinalMatrix = backOneConvMatrix(errorList, firstResConvPower, 2);
|
|
errorFinalMatrix = backDownPoolingByList(errorFinalMatrix);//退下池化
|
|
if (fill(deep, imageSize, false)) {//判断是否需要补0
|
|
fillZero(errorFinalMatrix, false);
|
|
}
|
|
ResBlockError(errorFinalMatrix, firstConvPower.getBackParameter(), firstConvPower.getMatrixNormList(),
|
|
firstConvPower.getConvPower(), studyRate, 7, null, firstConvPower.getDymStudyRateList(),
|
|
gaMa, gMaxTh, auto);
|
|
}
|
|
}
|
|
|
|
private void backFatherError(List<Matrix> errorMatrixList) throws Exception {//返回上一层误差
|
|
if (fill(deep, imageSize, true)) {//不是偶数,要补0
|
|
fillZero(errorMatrixList, false);
|
|
}
|
|
fatherResBlock.backError(errorMatrixList);
|
|
}
|
|
|
|
private List<Matrix> backOneConvMatrix(List<Matrix> errorMatrixList, ResConvPower resConvPower, int deep) throws Exception {
|
|
ConvLay firstConv = resConvPower.getFirstConvPower();
|
|
ConvLay secondConv = resConvPower.getSecondConvPower();
|
|
List<List<Float>> oneConvPower = null;
|
|
List<List<Float>> dymStudyRateList = null;
|
|
if (deep == 1) {
|
|
oneConvPower = resConvPower.getOneConvPower();
|
|
dymStudyRateList = resConvPower.getDymStudyRateList();
|
|
}
|
|
ResnetError resnetError = ResBlockError2(errorMatrixList, secondConv.getBackParameter(), secondConv.getMatrixNormList(),
|
|
secondConv.getConvPower(), studyRate, true, oneConvPower, null, secondConv.getDymStudyRateList()
|
|
, dymStudyRateList, gaMa, gMaxTh, auto);
|
|
List<Matrix> resErrorMatrixList = resnetError.getResErrorMatrixList();//残差误差
|
|
List<Matrix> nextErrorMatrixList = resnetError.getNextErrorMatrixList();//下一层误差
|
|
List<Matrix> errorList;
|
|
if (deep == 2) {
|
|
errorList = ResBlockError2(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
|
|
firstConv.getConvPower(), studyRate, false, oneConvPower, resErrorMatrixList,
|
|
firstConv.getDymStudyRateList(), dymStudyRateList, gaMa, gMaxTh, auto).getNextErrorMatrixList();
|
|
} else {
|
|
errorList = ResBlockError(nextErrorMatrixList, firstConv.getBackParameter(), firstConv.getMatrixNormList(),
|
|
firstConv.getConvPower(), studyRate, 3, resErrorMatrixList, firstConv.getDymStudyRateList(), gaMa, gMaxTh
|
|
, auto);
|
|
}
|
|
matrixOperation.mathMulByList(errorList, GRate);
|
|
return errorList;
|
|
}
|
|
|
|
private void convMatrix(List<Matrix> feature, int step, boolean study, OutBack outBack, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {// feature 准备跳层用
|
|
boolean one = step == 1;
|
|
List<Matrix> featureList = oneConvMatrix(feature, firstResConvPower, study, one);
|
|
List<Matrix> lastFeatureList = oneConvMatrix(featureList, secondResConvPower, study, true);
|
|
if (sonResBlock != null) {
|
|
sonResBlock.sendMatrixList(lastFeatureList, outBack, study, E, eventID, outFeature);
|
|
} else {//最后卷积层了,求平均值
|
|
List<Float> outFeatures = new ArrayList<>();
|
|
for (Matrix matrix : lastFeatureList) {
|
|
outFeatures.add(matrix.getAVG());
|
|
}
|
|
if (!study && outFeature) {
|
|
int size = outFeatures.size();
|
|
Matrix matrix = new Matrix(1, size);
|
|
for (int j = 0; j < size; j++) {
|
|
matrix.setNub(0, j, outFeatures.get(j));
|
|
}
|
|
if (outBack != null) {
|
|
outBack.getBackMatrix(matrix, 1, eventID);
|
|
} else {
|
|
throw new Exception("没有传入OutBack输出回调类");
|
|
}
|
|
} else {
|
|
if (sensoryNerves.size() == outFeatures.size()) {
|
|
int size = sensoryNerves.size();
|
|
for (int i = 0; i < size; i++) {
|
|
sensoryNerves.get(i).postMessage(eventID, outFeatures.get(i), study, E, outBack);
|
|
}
|
|
} else {
|
|
throw new Exception("线性层与特征层特征维度不相等");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
private List<Matrix> oneConvMatrix(List<Matrix> feature, ResConvPower resConvPower, boolean study, boolean one) throws Exception {
|
|
ConvLay firstConvLay = resConvPower.getFirstConvPower();
|
|
ConvLay secondConvLay = resConvPower.getSecondConvPower();
|
|
List<List<Float>> oneConvPower = resConvPower.getOneConvPower();
|
|
List<Matrix> firstOutMatrix;
|
|
if (one) {//步长为1
|
|
firstOutMatrix = downConvMany2(feature, firstConvLay.getConvPower(), study, firstConvLay.getBackParameter(),
|
|
firstConvLay.getMatrixNormList(), null, null);
|
|
} else {//步长为2
|
|
firstOutMatrix = downConvMany(feature, firstConvLay.getConvPower(), 3, study, firstConvLay.getBackParameter(),
|
|
firstConvLay.getMatrixNormList());
|
|
}
|
|
return downConvMany2(firstOutMatrix, secondConvLay.getConvPower(), study, secondConvLay.getBackParameter()
|
|
, secondConvLay.getMatrixNormList(), feature, oneConvPower);
|
|
}
|
|
|
|
public void sendMatrixList(List<Matrix> matrixList, OutBack outBack, boolean study, Map<Integer, Float> E, long eventID, boolean outFeature) throws Exception {
|
|
//判定特征大小是否为偶数
|
|
if (fill(deep, imageSize, true)) {//不是偶数,要补0
|
|
fillZero(matrixList, true);
|
|
}
|
|
if (deep == 1) {
|
|
List<Matrix> myMatrixList = downConvMany(matrixList, firstConvPower.getConvPower(), 7, study, firstConvPower.getBackParameter(),
|
|
firstConvPower.getMatrixNormList());
|
|
if (fill(deep, imageSize, false)) {//池化需要补0
|
|
fillZero(myMatrixList, true);
|
|
}
|
|
//池化
|
|
List<Matrix> nextMatrixList = new ArrayList<>();
|
|
for (Matrix matrix : myMatrixList) {
|
|
nextMatrixList.add(downPooling(matrix));
|
|
}
|
|
convMatrix(nextMatrixList, 1, study, outBack, E, eventID, outFeature);
|
|
} else {
|
|
convMatrix(matrixList, 2, study, outBack, E, eventID, outFeature);
|
|
}
|
|
|
|
}
|
|
|
|
private int getChannelNo() {
|
|
return (int) (channelNo * Math.pow(2, deep - 1));//卷积层输出特征大小
|
|
}
|
|
|
|
private void initBlock(ResConvPower resConvPower, Random random, boolean initOneConv) throws Exception {
|
|
resConvPower.setFirstConvPower(initMatrixPower(random, 3, getChannelNo(), false));
|
|
resConvPower.setSecondConvPower(initMatrixPower(random, 3, getChannelNo(), false));
|
|
if (deep > 1 && initOneConv) {//初始化11卷积层
|
|
int featureLength = getChannelNo();//卷积层输出特征大小
|
|
List<List<Float>> onePowers = new ArrayList<>();
|
|
List<List<Float>> dymStudyRateList = new ArrayList<>();
|
|
resConvPower.setOneConvPower(onePowers);
|
|
resConvPower.setDymStudyRateList(dymStudyRateList);
|
|
int length = featureLength / 2;
|
|
for (int i = 0; i < featureLength; i++) {
|
|
List<Float> oneConvPowerList = new ArrayList<>();
|
|
List<Float> dymStudyRage = new ArrayList<>();
|
|
for (int j = 0; j < length; j++) {
|
|
oneConvPowerList.add(random.nextFloat() / length);
|
|
dymStudyRage.add(0f);
|
|
}
|
|
dymStudyRateList.add(dymStudyRage);
|
|
onePowers.add(oneConvPowerList);
|
|
}
|
|
}
|
|
}
|
|
|
|
private ConvLay initMatrixPower(Random random, int kernLen, int channelNo, boolean seven) throws Exception {
|
|
int nerveNub = kernLen * kernLen;
|
|
ConvLay convLay = new ConvLay();
|
|
List<Matrix> nerveMatrixList = new ArrayList<>();//一层当中所有的深度卷积核
|
|
List<Matrix> sumOfSquares = new ArrayList<>();//动态学习率
|
|
List<MatrixNorm> matrixNormList = new ArrayList<>();
|
|
int size = getFeatureSize(deep, imageSize, seven);
|
|
for (int k = 0; k < channelNo; k++) {//遍历通道
|
|
Matrix nerveMatrix = new Matrix(nerveNub, 1);//一组通道创建一组卷积核
|
|
Matrix dymStudyRate = new Matrix(nerveNub, 1);
|
|
sumOfSquares.add(dymStudyRate);
|
|
for (int i = 0; i < nerveMatrix.getX(); i++) {//初始化深度卷积核权重
|
|
float nub = random.nextFloat() / kernLen;
|
|
nerveMatrix.setNub(i, 0, nub);
|
|
}
|
|
nerveMatrixList.add(nerveMatrix);
|
|
MatrixNorm matrixNorm = new MatrixNorm(size, studyRate, gaMa, gMaxTh, auto);
|
|
matrixNormList.add(matrixNorm);
|
|
}
|
|
convLay.setDymStudyRateList(sumOfSquares);
|
|
convLay.setConvPower(nerveMatrixList);
|
|
convLay.setMatrixNormList(matrixNormList);
|
|
return convLay;
|
|
}
|
|
|
|
public void setSonResBlock(ResBlock sonResBlock) {
|
|
this.sonResBlock = sonResBlock;
|
|
}
|
|
|
|
public void setFatherResBlock(ResBlock fatherResBlock) {
|
|
this.fatherResBlock = fatherResBlock;
|
|
}
|
|
}
|