Add File
This commit is contained in:
108
src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java
Normal file
108
src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java
Normal file
@@ -0,0 +1,108 @@
|
||||
package org.dromara.easyai.matrixTools;
|
||||
|
||||
import org.dromara.easyai.conv.DymStudy;
|
||||
import org.dromara.easyai.resnet.entity.NormModel;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @time 2025/4/11 15:27
|
||||
* @des 矩阵批量归一化
|
||||
*/
|
||||
public class MatrixNorm {
|
||||
private Matrix bTa;//偏移值
|
||||
private Matrix power;//膨胀系数矩阵
|
||||
private final Matrix bTaDymStudyRate;//偏移系数动态学习率
|
||||
private final Matrix powerDymStudyRate;//膨胀系数动态学习率
|
||||
private final DymStudy dymStudy;
|
||||
private final float studyRate;//全局学习率
|
||||
private final MatrixOperation matrixOperation = new MatrixOperation();
|
||||
private Matrix norm;
|
||||
|
||||
public NormModel getModel() {
|
||||
NormModel normModel = new NormModel();
|
||||
normModel.setBtaParameter(bTa.getMatrixModel());
|
||||
normModel.setPowerParameter(power.getMatrixModel());
|
||||
return normModel;
|
||||
}
|
||||
|
||||
public void insertModel(NormModel normModel) {
|
||||
bTa.insertMatrixModel(normModel.getBtaParameter());
|
||||
power.insertMatrixModel(normModel.getPowerParameter());
|
||||
}
|
||||
|
||||
public MatrixNorm(int size, float studyRate, float gaMa, float gMaxTh, boolean auTo) throws Exception {
|
||||
dymStudy = new DymStudy(gaMa, gMaxTh, auTo);
|
||||
bTa = new Matrix(size, size);
|
||||
power = new Matrix(size, size);
|
||||
bTaDymStudyRate = new Matrix(size, size);
|
||||
powerDymStudyRate = new Matrix(size, size);
|
||||
this.studyRate = studyRate;
|
||||
initPower(power);
|
||||
initPower(bTa);
|
||||
}
|
||||
|
||||
private void initPower(Matrix matrix) throws Exception {
|
||||
int size = matrix.getX();
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int j = 0; j < size; j++) {
|
||||
matrix.setNub(i, j, random.nextFloat() / size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Matrix back(Matrix errorMatrix, Matrix myData) throws Exception {
|
||||
Matrix subPower = matrixOperation.matrixMulPd(errorMatrix, myData, power, false);
|
||||
Matrix sub = matrixOperation.matrixMulPd(errorMatrix, myData, power, true);
|
||||
int x = sub.getX();
|
||||
int y = sub.getY();
|
||||
Matrix errorPower = dymStudy.getErrorMatrixByStudy(studyRate, powerDymStudyRate, subPower);
|
||||
power = matrixOperation.add(errorPower, power);
|
||||
float n = (float) Math.sqrt(x * y);
|
||||
float nt = -n / (n - 1);
|
||||
Matrix subMatrix = new Matrix(x, y);
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
float subValue = sub.getNumber(i, j);
|
||||
float value = subValue * n + subMatrix.getNumber(i, j);
|
||||
subMatrix.setNub(i, j, value);
|
||||
for (int k = 0; k < x; k++) {
|
||||
for (int l = 0; l < y; l++) {
|
||||
if (k != i || l != j) {
|
||||
float otherValue = subValue * nt + subMatrix.getNumber(k, l);
|
||||
subMatrix.setNub(k, l, otherValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return subMatrix;
|
||||
}
|
||||
|
||||
public Matrix backError(Matrix errorMatrix) throws Exception {
|
||||
Matrix error = dymStudy.getErrorMatrixByStudy(studyRate, bTaDymStudyRate, errorMatrix);
|
||||
bTa = matrixOperation.add(error, bTa);//更新bTa
|
||||
return back(errorMatrix, norm);
|
||||
}
|
||||
|
||||
public Matrix norm(Matrix matrix) throws Exception {
|
||||
int x = matrix.getX();
|
||||
int y = matrix.getY();
|
||||
if (x != y) {
|
||||
throw new Exception("必须是方阵才能进行全矩阵的归一化");
|
||||
}
|
||||
Matrix result = new Matrix(x, y);
|
||||
float avg = matrix.getAVG();//平均值
|
||||
float sd = matrixOperation.getSdByMatrix(matrix, avg, 0.0000001f);//标准差
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
float value = (matrix.getNumber(i, j) - avg) / sd;
|
||||
result.setNub(i, j, value);
|
||||
}
|
||||
}
|
||||
norm = result;
|
||||
return matrixOperation.add(matrixOperation.mulMatrix(result, power), bTa);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user