This commit is contained in:
2025-09-04 14:08:48 +08:00
parent 83a25d8baf
commit 15081c0b29

View File

@@ -0,0 +1,108 @@
package org.dromara.easyai.matrixTools;
import org.dromara.easyai.conv.DymStudy;
import org.dromara.easyai.resnet.entity.NormModel;
import java.util.Random;
/**
* @author lidapeng
* @time 2025/4/11 15:27
* @des 矩阵批量归一化
*/
public class MatrixNorm {
private Matrix bTa;//偏移值
private Matrix power;//膨胀系数矩阵
private final Matrix bTaDymStudyRate;//偏移系数动态学习率
private final Matrix powerDymStudyRate;//膨胀系数动态学习率
private final DymStudy dymStudy;
private final float studyRate;//全局学习率
private final MatrixOperation matrixOperation = new MatrixOperation();
private Matrix norm;
public NormModel getModel() {
NormModel normModel = new NormModel();
normModel.setBtaParameter(bTa.getMatrixModel());
normModel.setPowerParameter(power.getMatrixModel());
return normModel;
}
public void insertModel(NormModel normModel) {
bTa.insertMatrixModel(normModel.getBtaParameter());
power.insertMatrixModel(normModel.getPowerParameter());
}
public MatrixNorm(int size, float studyRate, float gaMa, float gMaxTh, boolean auTo) throws Exception {
dymStudy = new DymStudy(gaMa, gMaxTh, auTo);
bTa = new Matrix(size, size);
power = new Matrix(size, size);
bTaDymStudyRate = new Matrix(size, size);
powerDymStudyRate = new Matrix(size, size);
this.studyRate = studyRate;
initPower(power);
initPower(bTa);
}
private void initPower(Matrix matrix) throws Exception {
int size = matrix.getX();
Random random = new Random();
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
matrix.setNub(i, j, random.nextFloat() / size);
}
}
}
private Matrix back(Matrix errorMatrix, Matrix myData) throws Exception {
Matrix subPower = matrixOperation.matrixMulPd(errorMatrix, myData, power, false);
Matrix sub = matrixOperation.matrixMulPd(errorMatrix, myData, power, true);
int x = sub.getX();
int y = sub.getY();
Matrix errorPower = dymStudy.getErrorMatrixByStudy(studyRate, powerDymStudyRate, subPower);
power = matrixOperation.add(errorPower, power);
float n = (float) Math.sqrt(x * y);
float nt = -n / (n - 1);
Matrix subMatrix = new Matrix(x, y);
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
float subValue = sub.getNumber(i, j);
float value = subValue * n + subMatrix.getNumber(i, j);
subMatrix.setNub(i, j, value);
for (int k = 0; k < x; k++) {
for (int l = 0; l < y; l++) {
if (k != i || l != j) {
float otherValue = subValue * nt + subMatrix.getNumber(k, l);
subMatrix.setNub(k, l, otherValue);
}
}
}
}
}
return subMatrix;
}
public Matrix backError(Matrix errorMatrix) throws Exception {
Matrix error = dymStudy.getErrorMatrixByStudy(studyRate, bTaDymStudyRate, errorMatrix);
bTa = matrixOperation.add(error, bTa);//更新bTa
return back(errorMatrix, norm);
}
public Matrix norm(Matrix matrix) throws Exception {
int x = matrix.getX();
int y = matrix.getY();
if (x != y) {
throw new Exception("必须是方阵才能进行全矩阵的归一化");
}
Matrix result = new Matrix(x, y);
float avg = matrix.getAVG();//平均值
float sd = matrixOperation.getSdByMatrix(matrix, avg, 0.0000001f);//标准差
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
float value = (matrix.getNumber(i, j) - avg) / sd;
result.setNub(i, j, value);
}
}
norm = result;
return matrixOperation.add(matrixOperation.mulMatrix(result, power), bTa);
}
}