Add File
This commit is contained in:
87
src/main/java/org/dromara/easyai/resnet/ResnetManager.java
Normal file
87
src/main/java/org/dromara/easyai/resnet/ResnetManager.java
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package org.dromara.easyai.resnet;
|
||||||
|
|
||||||
|
import org.dromara.easyai.config.ResnetConfig;
|
||||||
|
import org.dromara.easyai.conv.ResConvCount;
|
||||||
|
import org.dromara.easyai.i.ActiveFunction;
|
||||||
|
import org.dromara.easyai.nerveCenter.NerveManager;
|
||||||
|
import org.dromara.easyai.nerveEntity.SensoryNerve;
|
||||||
|
import org.dromara.easyai.resnet.entity.ResBlockModel;
|
||||||
|
import org.dromara.easyai.resnet.entity.ResnetModel;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lidapeng
|
||||||
|
* @time 2025/4/11 10:51
|
||||||
|
* @des resnet管理器
|
||||||
|
*/
|
||||||
|
public class ResnetManager extends ResConvCount {
|
||||||
|
private final NerveManager nerveManager;
|
||||||
|
private final List<ResBlock> resBlockList = new ArrayList<>();//残差集合
|
||||||
|
private final ResnetInput restNetInput;
|
||||||
|
|
||||||
|
public ResnetInput getRestNetInput() {
|
||||||
|
return restNetInput;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ResnetModel getModel() throws Exception {
|
||||||
|
ResnetModel resnetModel = new ResnetModel();
|
||||||
|
List<ResBlockModel> resBlockModelList = new ArrayList<>();
|
||||||
|
resnetModel.setResBlockModelList(resBlockModelList);
|
||||||
|
for (ResBlock resBlock : resBlockList) {
|
||||||
|
resBlockModelList.add(resBlock.getModel());
|
||||||
|
}
|
||||||
|
resnetModel.setParameter(nerveManager.getDnnModel());
|
||||||
|
return resnetModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertModel(ResnetModel resnetModel) {
|
||||||
|
List<ResBlockModel> resBlockModelList = resnetModel.getResBlockModelList();
|
||||||
|
int size = resBlockList.size();
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
resBlockList.get(i).insertModel(resBlockModelList.get(i));
|
||||||
|
}
|
||||||
|
nerveManager.insertDnnModel(resnetModel.getParameter());
|
||||||
|
}
|
||||||
|
|
||||||
|
public ResnetManager(ResnetConfig resNetConfig, ActiveFunction activeFunction) throws Exception {
|
||||||
|
int deep = getConvDeep(resNetConfig.getSize(), resNetConfig.getMinFeatureSize());//获取深度
|
||||||
|
int channelNo = resNetConfig.getChannelNo();//通道数
|
||||||
|
int lastSize = getFeatureSize(deep, resNetConfig.getSize(), true);//最后一层特征大小
|
||||||
|
//全局学习率
|
||||||
|
float studyRate = resNetConfig.getStudyRate();
|
||||||
|
if (deep < 2) {
|
||||||
|
throw new Exception("图像尺寸太小了,不能用resnet进行训练");
|
||||||
|
}
|
||||||
|
int featureLength = (int) (channelNo * Math.pow(2, deep - 1));//卷积层输出特征大小
|
||||||
|
nerveManager = new NerveManager(featureLength, resNetConfig.getHiddenNerveNumber(), resNetConfig.getTypeNumber(), resNetConfig.getHiddenDeep()
|
||||||
|
, activeFunction, studyRate, resNetConfig.getRegularModel(), resNetConfig.getRegular(), 0, resNetConfig.getGaMa()
|
||||||
|
, resNetConfig.getGMaxTh(), resNetConfig.isAuto());
|
||||||
|
ResNetConnectionLine resNetConnectionLine = new ResNetConnectionLine();
|
||||||
|
nerveManager.init(true, resNetConfig.isShowLog(), resNetConfig.isSoftMax(), resNetConnectionLine);
|
||||||
|
for (int i = 0; i < deep; i++) {
|
||||||
|
List<SensoryNerve> sensoryNerves = null;
|
||||||
|
if (i == deep - 1) {
|
||||||
|
sensoryNerves = nerveManager.getSensoryNerves();
|
||||||
|
}
|
||||||
|
ResBlock resBlock = new ResBlock(channelNo, i + 1, studyRate, resNetConfig.getSize(), sensoryNerves, resNetConfig.getGaMa()
|
||||||
|
, resNetConfig.getGMaxTh(), resNetConfig.isAuto(), resNetConfig.getGRate());
|
||||||
|
resBlockList.add(resBlock);
|
||||||
|
}
|
||||||
|
restNetInput = new ResnetInput(resBlockList.get(0), resNetConfig.getSize());
|
||||||
|
connection();//残差块进行互相连接
|
||||||
|
resNetConnectionLine.setLastBlock(resBlockList.get(deep - 1), lastSize, resNetConfig.getHiddenNerveNumber(), featureLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void connection() {//残差块相互连接
|
||||||
|
int size = resBlockList.size();
|
||||||
|
for (int i = 0; i < size - 1; i++) {
|
||||||
|
ResBlock resBlock = resBlockList.get(i);
|
||||||
|
ResBlock nextResBlock = resBlockList.get(i + 1);
|
||||||
|
resBlock.setSonResBlock(nextResBlock);
|
||||||
|
nextResBlock.setFatherResBlock(resBlock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user