Add File
This commit is contained in:
181
src/main/java/org/dromara/easyai/unet/UNetEncoder.java
Normal file
181
src/main/java/org/dromara/easyai/unet/UNetEncoder.java
Normal file
@@ -0,0 +1,181 @@
|
||||
package org.dromara.easyai.unet;
|
||||
|
||||
import org.dromara.easyai.conv.ConvCount;
|
||||
import org.dromara.easyai.entity.ThreeChannelMatrix;
|
||||
import org.dromara.easyai.i.ActiveFunction;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.nerveEntity.ConvParameter;
|
||||
import org.dromara.easyai.nerveEntity.ConvSize;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @time 2025/3/2 07:51
|
||||
* @des unet编码器
|
||||
*/
|
||||
public class UNetEncoder extends ConvCount {
|
||||
private final ConvParameter convParameter = new ConvParameter();//内存中卷积层模型及临时数据
|
||||
private final MatrixOperation matrixOperation = new MatrixOperation();
|
||||
private final int kerSize;
|
||||
private final float studyRate;//学习率
|
||||
private final int deep;//当前深度
|
||||
private final int channelNo;//卷积层数
|
||||
private List<Matrix> decodeErrorMatrix;//从解码器传来的误差矩阵
|
||||
private final ActiveFunction activeFunction;
|
||||
private UNetEncoder afterEncoder;//下一个编码器
|
||||
private UNetEncoder beforeEncoder;//上一个编码器
|
||||
private UNetDecoder decoder;//下一个解码器
|
||||
private final int xSize;
|
||||
private final int ySize;
|
||||
private final float oneStudyRate;
|
||||
private final float gaMa;
|
||||
private final float gMaxTh;
|
||||
private final boolean aoTu;
|
||||
|
||||
public UNetEncoder(int kerSize, int channelNo, int deep, ActiveFunction activeFunction
|
||||
, float studyRate, int xSize, int ySize, float oneStudyRate, float gaMa, float gMaxTh, boolean aoTu) throws Exception {//核心大小
|
||||
Random random = new Random();
|
||||
this.xSize = xSize;
|
||||
this.aoTu = aoTu;
|
||||
this.gMaxTh = gMaxTh;
|
||||
this.gaMa = gaMa;
|
||||
this.ySize = ySize;
|
||||
this.oneStudyRate = oneStudyRate;
|
||||
this.studyRate = studyRate;
|
||||
this.kerSize = kerSize;
|
||||
this.activeFunction = activeFunction;
|
||||
this.deep = deep;
|
||||
this.channelNo = channelNo;
|
||||
List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
|
||||
List<ConvSize> convSizeList = convParameter.getConvSizeList();
|
||||
List<Matrix> dymStudyRateList = convParameter.getDymStudyRateList();
|
||||
for (int i = 0; i < channelNo; i++) {
|
||||
initNervePowerMatrix(random, nerveMatrixList, dymStudyRateList);
|
||||
convSizeList.add(new ConvSize());
|
||||
}
|
||||
if (deep == 1) {
|
||||
List<List<Float>> oneConvPowers = new ArrayList<>();
|
||||
List<List<Float>> oneDymStudyRateList = new ArrayList<>();
|
||||
for (int k = 0; k < channelNo; k++) {
|
||||
List<Float> oneConvPower = new ArrayList<>();
|
||||
List<Float> oneDymStudyRate = new ArrayList<>();
|
||||
oneConvPowers.add(oneConvPower);
|
||||
oneDymStudyRateList.add(oneDymStudyRate);
|
||||
//通道数
|
||||
int channelNum = 3;
|
||||
for (int i = 0; i < channelNum; i++) {
|
||||
oneConvPower.add(random.nextFloat() / channelNum);
|
||||
oneDymStudyRate.add(0f);
|
||||
}
|
||||
}
|
||||
convParameter.setOneDymStudyRateList(oneDymStudyRateList);
|
||||
convParameter.setOneConvPower(oneConvPowers);
|
||||
}
|
||||
}
|
||||
|
||||
public ConvParameter getConvParameter() {
|
||||
return convParameter;
|
||||
}
|
||||
|
||||
protected void setDecodeErrorMatrix(List<Matrix> decodeErrorMatrix) {
|
||||
this.decodeErrorMatrix = decodeErrorMatrix;
|
||||
}
|
||||
|
||||
protected List<Matrix> getAfterConvMatrix(long eventID) {//卷积后的矩阵
|
||||
List<Matrix> outMatrixList = convParameter.getFeatureMap().get(eventID);
|
||||
convParameter.getFeatureMap().remove(eventID);
|
||||
return outMatrixList;
|
||||
}
|
||||
|
||||
//发送特征三通道矩阵
|
||||
public void sendThreeChannel(long eventID, OutBack outBack, ThreeChannelMatrix feature, ThreeChannelMatrix featureE,
|
||||
boolean study) throws Exception {
|
||||
if (study && featureE == null) {
|
||||
throw new Exception("训练时期望矩阵不能为空");
|
||||
}
|
||||
if (feature.getX() != xSize && feature.getY() != ySize) {
|
||||
throw new Exception("输入图片尺寸与初始化参数不一致");
|
||||
}
|
||||
List<Matrix> matrixList = new ArrayList<>();
|
||||
matrixList.add(feature.getMatrixR());
|
||||
matrixList.add(feature.getMatrixG());
|
||||
matrixList.add(feature.getMatrixB());
|
||||
if (study) {
|
||||
convParameter.setFeatureMatrixList(matrixList);
|
||||
}
|
||||
sendMatrixList(eventID, outBack, featureE, matrixList, study, feature);
|
||||
}
|
||||
|
||||
protected void sendFeature(long eventID, OutBack outBack, ThreeChannelMatrix featureE,
|
||||
List<Matrix> myFeatures, boolean study, ThreeChannelMatrix backGround) throws Exception {
|
||||
List<Matrix> convMatrixList = downConvAndPooling(myFeatures, convParameter, channelNo, activeFunction, kerSize, true, eventID);
|
||||
if (afterEncoder != null) {//后面还有编码器,继续向后传递
|
||||
afterEncoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
|
||||
} else {//向解码器传递
|
||||
decoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
|
||||
}
|
||||
}
|
||||
|
||||
protected void backError(List<Matrix> errorMatrix) throws Exception {//接收误差
|
||||
List<Matrix> errorList = backDownPoolingByList(errorMatrix, convParameter.getOutX(), convParameter.getOutY());//池化误差返回
|
||||
List<Matrix> errorMatrixList = matrixOperation.addMatrixList(errorList, decodeErrorMatrix);
|
||||
List<Matrix> myErrorMatrix = backAllDownConv(convParameter, errorMatrixList, studyRate, activeFunction, channelNo, kerSize,
|
||||
gaMa, gMaxTh, aoTu);
|
||||
if (beforeEncoder != null) {
|
||||
beforeEncoder.backError(myErrorMatrix);
|
||||
} else {//最后一层 调整1v1卷积
|
||||
backOneConvByList(myErrorMatrix, convParameter.getFeatureMatrixList(), convParameter.getOneConvPower(), oneStudyRate
|
||||
, convParameter.getOneDymStudyRateList(), gaMa, gMaxTh, aoTu);
|
||||
}
|
||||
}
|
||||
|
||||
public void sendMatrixList(long eventID, OutBack outBack, ThreeChannelMatrix featureE, List<Matrix> feature,
|
||||
boolean study, ThreeChannelMatrix backGround) throws Exception {
|
||||
List<Matrix> myFeatures = manyOneConv(feature, convParameter.getOneConvPower());//矩阵重新调整维度
|
||||
List<Matrix> convMatrixList = downConvAndPooling(myFeatures, convParameter, channelNo, activeFunction, kerSize, true, eventID);
|
||||
if (afterEncoder != null) {//后面还有编码器,继续向后传递
|
||||
afterEncoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
|
||||
} else {//向解码器传递
|
||||
decoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
|
||||
}
|
||||
}
|
||||
|
||||
private void initNervePowerMatrix(Random random, List<Matrix> nervePowerMatrixList, List<Matrix> dymStudyRageList) throws Exception {
|
||||
int convSize = kerSize * kerSize;
|
||||
Matrix nervePowerMatrix = new Matrix(convSize, 1);
|
||||
for (int i = 0; i < convSize; i++) {
|
||||
float power = random.nextFloat() / kerSize;
|
||||
nervePowerMatrix.setNub(i, 0, power);
|
||||
}
|
||||
dymStudyRageList.add(new Matrix(convSize, 1));
|
||||
nervePowerMatrixList.add(nervePowerMatrix);
|
||||
}
|
||||
|
||||
public UNetEncoder getAfterEncoder() {
|
||||
return afterEncoder;
|
||||
}
|
||||
|
||||
public void setAfterEncoder(UNetEncoder afterEncoder) {
|
||||
this.afterEncoder = afterEncoder;
|
||||
}
|
||||
|
||||
public UNetEncoder getBeforeEncoder() {
|
||||
return beforeEncoder;
|
||||
}
|
||||
|
||||
public void setBeforeEncoder(UNetEncoder beforeEncoder) {
|
||||
this.beforeEncoder = beforeEncoder;
|
||||
}
|
||||
|
||||
public UNetDecoder getDecoder() {
|
||||
return decoder;
|
||||
}
|
||||
|
||||
public void setDecoder(UNetDecoder decoder) {
|
||||
this.decoder = decoder;
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user