From 15081c0b29e4e5d941e5ecea689cd5bbb560ff45 Mon Sep 17 00:00:00 2001 From: inter Date: Thu, 4 Sep 2025 14:08:48 +0800 Subject: [PATCH] Add File --- .../easyai/matrixTools/MatrixNorm.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java diff --git a/src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java b/src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java new file mode 100644 index 0000000..3fde9ab --- /dev/null +++ b/src/main/java/org/dromara/easyai/matrixTools/MatrixNorm.java @@ -0,0 +1,108 @@ +package org.dromara.easyai.matrixTools; + +import org.dromara.easyai.conv.DymStudy; +import org.dromara.easyai.resnet.entity.NormModel; + +import java.util.Random; + +/** + * @author lidapeng + * @time 2025/4/11 15:27 + * @des 矩阵批量归一化 + */ +public class MatrixNorm { + private Matrix bTa;//偏移值 + private Matrix power;//膨胀系数矩阵 + private final Matrix bTaDymStudyRate;//偏移系数动态学习率 + private final Matrix powerDymStudyRate;//膨胀系数动态学习率 + private final DymStudy dymStudy; + private final float studyRate;//全局学习率 + private final MatrixOperation matrixOperation = new MatrixOperation(); + private Matrix norm; + + public NormModel getModel() { + NormModel normModel = new NormModel(); + normModel.setBtaParameter(bTa.getMatrixModel()); + normModel.setPowerParameter(power.getMatrixModel()); + return normModel; + } + + public void insertModel(NormModel normModel) { + bTa.insertMatrixModel(normModel.getBtaParameter()); + power.insertMatrixModel(normModel.getPowerParameter()); + } + + public MatrixNorm(int size, float studyRate, float gaMa, float gMaxTh, boolean auTo) throws Exception { + dymStudy = new DymStudy(gaMa, gMaxTh, auTo); + bTa = new Matrix(size, size); + power = new Matrix(size, size); + bTaDymStudyRate = new Matrix(size, size); + powerDymStudyRate = new Matrix(size, size); + this.studyRate = studyRate; + initPower(power); + initPower(bTa); + } + + private void initPower(Matrix matrix) throws Exception { + int size = matrix.getX(); + Random random = new Random(); + for (int i = 0; i < size; i++) { + for (int j = 0; j < size; j++) { + matrix.setNub(i, j, random.nextFloat() / size); + } + } + } + + private Matrix back(Matrix errorMatrix, Matrix myData) throws Exception { + Matrix subPower = matrixOperation.matrixMulPd(errorMatrix, myData, power, false); + Matrix sub = matrixOperation.matrixMulPd(errorMatrix, myData, power, true); + int x = sub.getX(); + int y = sub.getY(); + Matrix errorPower = dymStudy.getErrorMatrixByStudy(studyRate, powerDymStudyRate, subPower); + power = matrixOperation.add(errorPower, power); + float n = (float) Math.sqrt(x * y); + float nt = -n / (n - 1); + Matrix subMatrix = new Matrix(x, y); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + float subValue = sub.getNumber(i, j); + float value = subValue * n + subMatrix.getNumber(i, j); + subMatrix.setNub(i, j, value); + for (int k = 0; k < x; k++) { + for (int l = 0; l < y; l++) { + if (k != i || l != j) { + float otherValue = subValue * nt + subMatrix.getNumber(k, l); + subMatrix.setNub(k, l, otherValue); + } + } + } + } + } + return subMatrix; + } + + public Matrix backError(Matrix errorMatrix) throws Exception { + Matrix error = dymStudy.getErrorMatrixByStudy(studyRate, bTaDymStudyRate, errorMatrix); + bTa = matrixOperation.add(error, bTa);//更新bTa + return back(errorMatrix, norm); + } + + public Matrix norm(Matrix matrix) throws Exception { + int x = matrix.getX(); + int y = matrix.getY(); + if (x != y) { + throw new Exception("必须是方阵才能进行全矩阵的归一化"); + } + Matrix result = new Matrix(x, y); + float avg = matrix.getAVG();//平均值 + float sd = matrixOperation.getSdByMatrix(matrix, avg, 0.0000001f);//标准差 + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + float value = (matrix.getNumber(i, j) - avg) / sd; + result.setNub(i, j, value); + } + } + norm = result; + return matrixOperation.add(matrixOperation.mulMatrix(result, power), bTa); + } +}