Add File
This commit is contained in:
299
src/main/java/org/dromara/easyai/unet/UNetDecoder.java
Normal file
299
src/main/java/org/dromara/easyai/unet/UNetDecoder.java
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package org.dromara.easyai.unet;
|
||||||
|
|
||||||
|
import org.dromara.easyai.conv.ConvCount;
|
||||||
|
import org.dromara.easyai.conv.DymStudy;
|
||||||
|
import org.dromara.easyai.entity.ThreeChannelMatrix;
|
||||||
|
import org.dromara.easyai.i.ActiveFunction;
|
||||||
|
import org.dromara.easyai.i.OutBack;
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||||
|
import org.dromara.easyai.nerveEntity.ConvParameter;
|
||||||
|
import org.dromara.easyai.nerveEntity.ConvSize;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lidapeng
|
||||||
|
* @time 2025/3/2 07:55
|
||||||
|
* @des unet解码器
|
||||||
|
*/
|
||||||
|
public class UNetDecoder extends ConvCount {
|
||||||
|
private final ConvParameter convParameter = new ConvParameter();//内存中卷积层模型及临时数据
|
||||||
|
private final MatrixOperation matrixOperation = new MatrixOperation();
|
||||||
|
private final int kerSize;
|
||||||
|
private final int deep;//当前深度
|
||||||
|
private final float studyRate;//学习率
|
||||||
|
private final int channelNo;//通道数
|
||||||
|
private final boolean lastLay;//是否为最后一层
|
||||||
|
private final ActiveFunction activeFunction;
|
||||||
|
private UNetDecoder afterDecoder;//下一个解码器
|
||||||
|
private UNetDecoder beforeDecoder;//上一个解码器
|
||||||
|
private UNetEncoder encoder;//上一个编码器
|
||||||
|
private UNetEncoder myUNetEncoder;//同级编码器
|
||||||
|
private final ConvSize convSize = new ConvSize();
|
||||||
|
private final Cutting cutting;//输出语义切割图像
|
||||||
|
private final float oneConvStudyRate;//
|
||||||
|
private final float gaMa;
|
||||||
|
private final float gMaxTh;
|
||||||
|
private final boolean autoStudyRate;//自动学习率
|
||||||
|
|
||||||
|
public UNetDecoder(int kerSize, int deep, int channelNo, ActiveFunction activeFunction
|
||||||
|
, boolean lastLay, float studyRate, Cutting cutting, float oneConvStudyRate, float gaMa, float gMaxTh, boolean autoStudyRate) throws Exception {
|
||||||
|
this.cutting = cutting;
|
||||||
|
this.autoStudyRate = autoStudyRate;
|
||||||
|
this.gMaxTh = gMaxTh;
|
||||||
|
this.gaMa = gaMa;
|
||||||
|
this.kerSize = kerSize;
|
||||||
|
this.oneConvStudyRate = oneConvStudyRate;
|
||||||
|
this.deep = deep;
|
||||||
|
this.studyRate = studyRate;
|
||||||
|
this.lastLay = lastLay;
|
||||||
|
this.channelNo = channelNo;
|
||||||
|
this.activeFunction = activeFunction;
|
||||||
|
Random random = new Random();
|
||||||
|
List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
|
||||||
|
List<Matrix> dymStudyRateList = convParameter.getDymStudyRateList();
|
||||||
|
List<Matrix> upNeverMatrixList = convParameter.getUpNerveMatrixList();//上卷积采样权重
|
||||||
|
List<Matrix> upDYmStudyRateList = convParameter.getUpDymStudyRateList();
|
||||||
|
List<ConvSize> convSizeList = convParameter.getConvSizeList();
|
||||||
|
for (int i = 0; i < channelNo; i++) {
|
||||||
|
int convSize = kerSize * kerSize;
|
||||||
|
upDYmStudyRateList.add(new Matrix(1, convSize));
|
||||||
|
upNeverMatrixList.add(initUpNervePowerMatrix(random));
|
||||||
|
initNervePowerMatrix(random, nerveMatrixList, dymStudyRateList);
|
||||||
|
convSizeList.add(new ConvSize());
|
||||||
|
}
|
||||||
|
if (lastLay) {
|
||||||
|
List<Float> oneConvPower = new ArrayList<>();
|
||||||
|
List<Float> oneDymStudyRate = new ArrayList<>();
|
||||||
|
for (int i = 0; i < channelNo; i++) {
|
||||||
|
oneConvPower.add(random.nextFloat() / channelNo);
|
||||||
|
oneDymStudyRate.add(0f);
|
||||||
|
}
|
||||||
|
convParameter.setUpOneDymStudyRateList(oneDymStudyRate);
|
||||||
|
convParameter.setUpOneConvPower(oneConvPower);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public ConvParameter getConvParameter() {
|
||||||
|
return convParameter;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ThreeChannelMatrix fillColor(ThreeChannelMatrix picture, int heightSize, int widthSize) throws Exception {
|
||||||
|
int myFaceHeight = picture.getX();
|
||||||
|
int sub = myFaceHeight - heightSize;
|
||||||
|
int fillHeight = sub / 2;//高度差
|
||||||
|
if (fillHeight == 0) {
|
||||||
|
fillHeight = 1;
|
||||||
|
}
|
||||||
|
ThreeChannelMatrix fillMatrix = null;
|
||||||
|
if (sub > 0) {//剪切
|
||||||
|
fillMatrix = picture.cutChannel(fillHeight, 0, heightSize, widthSize);
|
||||||
|
} else if (sub < 0) {//补0
|
||||||
|
fillMatrix = getFaceMatrix(heightSize, widthSize);
|
||||||
|
fillMatrix.fill(Math.abs(fillHeight), 0, picture);
|
||||||
|
}
|
||||||
|
return fillMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ThreeChannelMatrix getFaceMatrix(int height, int width) {
|
||||||
|
ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
|
||||||
|
Matrix matrixR = new Matrix(height, width);
|
||||||
|
Matrix matrixG = new Matrix(height, width);
|
||||||
|
Matrix matrixB = new Matrix(height, width);
|
||||||
|
Matrix matrixH = new Matrix(height, width);
|
||||||
|
threeChannelMatrix.setX(height);
|
||||||
|
threeChannelMatrix.setY(width);
|
||||||
|
threeChannelMatrix.setMatrixR(matrixR);
|
||||||
|
threeChannelMatrix.setMatrixG(matrixG);
|
||||||
|
threeChannelMatrix.setMatrixB(matrixB);
|
||||||
|
threeChannelMatrix.setH(matrixH);
|
||||||
|
return threeChannelMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFeatures(List<Matrix> encoderFeatures, List<Matrix> myFeatures, boolean study) throws Exception {
|
||||||
|
int size = encoderFeatures.size();
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
addFeature(encoderFeatures.get(i), myFeatures.get(i), study);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFeature(Matrix encoderFeature, Matrix myFeature, boolean study) throws Exception {//获取残差块
|
||||||
|
if (study) {
|
||||||
|
convSize.setXInput(encoderFeature.getX());
|
||||||
|
convSize.setYInput(encoderFeature.getY());
|
||||||
|
}
|
||||||
|
int tx = encoderFeature.getX();
|
||||||
|
int ty = encoderFeature.getY();
|
||||||
|
int x = myFeature.getX();
|
||||||
|
int y = myFeature.getY();
|
||||||
|
for (int i = 0; i < x; i++) {
|
||||||
|
for (int j = 0; j < y; j++) {
|
||||||
|
float encoderValue = 0;
|
||||||
|
if (i < tx && j < ty) {
|
||||||
|
encoderValue = encoderFeature.getNumber(i, j);
|
||||||
|
}
|
||||||
|
float value = (myFeature.getNumber(i, j) + encoderValue) / 2;
|
||||||
|
myFeature.setNub(i, j, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void toThreeChannelMatrix(List<Matrix> features, ThreeChannelMatrix featureE, boolean study, OutBack outBack
|
||||||
|
, ThreeChannelMatrix backGround) throws Exception {
|
||||||
|
int x = features.get(0).getX();
|
||||||
|
int y = features.get(0).getY();
|
||||||
|
List<Float> upOneConvPower = convParameter.getUpOneConvPower();
|
||||||
|
Matrix feature = oneConv(features, upOneConvPower);
|
||||||
|
if (study) {//训练
|
||||||
|
ThreeChannelMatrix sfe = featureE.scale(true, y);//缩放
|
||||||
|
ThreeChannelMatrix fe = fillColor(sfe, x, y);//补0
|
||||||
|
if (fe == null) {
|
||||||
|
fe = sfe;
|
||||||
|
}
|
||||||
|
Matrix he = fe.calculateAvgGrayscale();
|
||||||
|
Matrix errorMatrix = matrixOperation.sub(he, feature);//总误差
|
||||||
|
//先更新分矩阵误差
|
||||||
|
List<Matrix> errorMatrixList = new ArrayList<>();
|
||||||
|
for (int i = 0; i < channelNo; i++) {
|
||||||
|
float power = upOneConvPower.get(i);
|
||||||
|
Matrix error = matrixOperation.mathMulBySelf(errorMatrix, power);
|
||||||
|
errorMatrixList.add(error);
|
||||||
|
}
|
||||||
|
DymStudy dymStudy = new DymStudy(gaMa, gMaxTh, autoStudyRate);
|
||||||
|
backOneConv(errorMatrix, features, upOneConvPower, oneConvStudyRate, convParameter.getUpOneDymStudyRateList(), dymStudy);//更新1v1卷积核
|
||||||
|
backLastError(errorMatrixList);
|
||||||
|
//误差矩阵开始back
|
||||||
|
} else {//输出
|
||||||
|
int mx = backGround.getX();
|
||||||
|
int my = backGround.getY();
|
||||||
|
int startX = (mx - feature.getX()) / 2;
|
||||||
|
int startY = (my - feature.getY()) / 2;
|
||||||
|
Matrix myMatrix = new Matrix(mx, my);
|
||||||
|
for (int i = startX; i < x; i++) {
|
||||||
|
for (int j = startY; j < y; j++) {
|
||||||
|
myMatrix.setNub(i, j, feature.getNumber(i - startX, j - startY));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
|
||||||
|
threeChannelMatrix.setX(x);
|
||||||
|
threeChannelMatrix.setY(y);
|
||||||
|
threeChannelMatrix.setMatrixR(myMatrix);
|
||||||
|
threeChannelMatrix.setMatrixG(myMatrix);
|
||||||
|
threeChannelMatrix.setMatrixB(myMatrix);
|
||||||
|
if (cutting != null) {
|
||||||
|
cutting.cut(backGround, threeChannelMatrix, outBack);
|
||||||
|
} else {
|
||||||
|
outBack.getBackThreeChannelMatrix(threeChannelMatrix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void backLastError(List<Matrix> errorMatrixList) throws Exception {//最后一层的误差反向传播
|
||||||
|
List<Matrix> errorList = backAllDownConv(convParameter, errorMatrixList, studyRate, activeFunction, channelNo, kerSize, gaMa, gMaxTh
|
||||||
|
, autoStudyRate);
|
||||||
|
sendEncoderError(errorList);//给同级解码器发送误差
|
||||||
|
beforeDecoder.backErrorMatrix(errorList);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendEncoderError(List<Matrix> errors) throws Exception {//给同级解码器发送误差
|
||||||
|
List<Matrix> encoderErrors = new ArrayList<>();
|
||||||
|
for (Matrix error : errors) {
|
||||||
|
Matrix encoderError = new Matrix(convSize.getXInput(), convSize.getYInput());
|
||||||
|
int x = convSize.getXInput();
|
||||||
|
int y = convSize.getYInput();
|
||||||
|
int tx = error.getX();
|
||||||
|
int ty = error.getY();
|
||||||
|
for (int i = 0; i < x; i++) {
|
||||||
|
for (int j = 0; j < y; j++) {
|
||||||
|
float value = 0;
|
||||||
|
if (i < tx && j < ty) {
|
||||||
|
value = error.getNumber(i, j) / 2;
|
||||||
|
}
|
||||||
|
encoderError.setNub(i, j, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoderErrors.add(encoderError);
|
||||||
|
}
|
||||||
|
myUNetEncoder.setDecodeErrorMatrix(encoderErrors);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void backErrorMatrix(List<Matrix> myErrorMatrixList) throws Exception {//接收解码器误差
|
||||||
|
//退上池化,退上卷积 退下卷积 并返回编码器误差
|
||||||
|
List<Matrix> errorList = backManyUpPooling(myErrorMatrixList);//退上池化
|
||||||
|
List<Matrix> errorMatrixList = backManyUpConv(errorList, kerSize, convParameter, studyRate, activeFunction, gaMa, gMaxTh, autoStudyRate);//退上卷积
|
||||||
|
List<Matrix> backList = backAllDownConv(convParameter, errorMatrixList, studyRate, activeFunction, channelNo, kerSize, gaMa, gMaxTh
|
||||||
|
, autoStudyRate);//退下卷积
|
||||||
|
if (myUNetEncoder != null) {
|
||||||
|
sendEncoderError(backList);//给同级编码器发送误差
|
||||||
|
}
|
||||||
|
if (beforeDecoder != null) {
|
||||||
|
beforeDecoder.backErrorMatrix(backList);
|
||||||
|
} else {//给上一个编码器发送误差
|
||||||
|
encoder.backError(backList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void sendFeature(long eventID, OutBack outBack, ThreeChannelMatrix featureE,
|
||||||
|
List<Matrix> myFeatures, boolean study, ThreeChannelMatrix backGround) throws Exception {
|
||||||
|
if (deep > 1) {
|
||||||
|
List<Matrix> encoderMatrixList = myUNetEncoder.getAfterConvMatrix(eventID);//编码器特征
|
||||||
|
addFeatures(encoderMatrixList, myFeatures, study);
|
||||||
|
}
|
||||||
|
List<Matrix> upConvMatrixList = upConvAndPooling(myFeatures, convParameter, channelNo, activeFunction, kerSize, !lastLay);
|
||||||
|
if (lastLay) {//最后一层解码器
|
||||||
|
toThreeChannelMatrix(upConvMatrixList, featureE, study, outBack, backGround);
|
||||||
|
} else {
|
||||||
|
afterDecoder.sendFeature(eventID, outBack, featureE, upConvMatrixList, study, backGround);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Matrix initUpNervePowerMatrix(Random random) throws Exception {
|
||||||
|
int convSize = kerSize * kerSize;
|
||||||
|
Matrix nervePowerMatrix = new Matrix(1, convSize);
|
||||||
|
for (int j = 0; j < convSize; j++) {
|
||||||
|
float power = random.nextFloat() / kerSize;
|
||||||
|
nervePowerMatrix.setNub(0, j, power);
|
||||||
|
}
|
||||||
|
return nervePowerMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initNervePowerMatrix(Random random, List<Matrix> nervePowerMatrixList, List<Matrix> dymStudyRateList) throws Exception {
|
||||||
|
int convSize = kerSize * kerSize;
|
||||||
|
Matrix nervePowerMatrix = new Matrix(convSize, 1);
|
||||||
|
for (int i = 0; i < convSize; i++) {
|
||||||
|
float power = random.nextFloat() / kerSize;
|
||||||
|
nervePowerMatrix.setNub(i, 0, power);
|
||||||
|
}
|
||||||
|
dymStudyRateList.add(new Matrix(convSize, 1));
|
||||||
|
nervePowerMatrixList.add(nervePowerMatrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
public UNetDecoder getAfterDecoder() {
|
||||||
|
return afterDecoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAfterDecoder(UNetDecoder afterDecoder) {
|
||||||
|
this.afterDecoder = afterDecoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public UNetDecoder getBeforeDecoder() {
|
||||||
|
return beforeDecoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBeforeDecoder(UNetDecoder beforeDecoder) {
|
||||||
|
this.beforeDecoder = beforeDecoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public UNetEncoder getEncoder() {
|
||||||
|
return encoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setEncoder(UNetEncoder encoder) {
|
||||||
|
this.encoder = encoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMyUNetEncoder(UNetEncoder myUNetEncoder) {
|
||||||
|
this.myUNetEncoder = myUNetEncoder;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user