Add File
This commit is contained in:
94
src/main/java/org/dromara/easyai/tsd/QBlock.java
Normal file
94
src/main/java/org/dromara/easyai/tsd/QBlock.java
Normal file
@@ -0,0 +1,94 @@
|
||||
package org.dromara.easyai.tsd;
|
||||
|
||||
import org.dromara.easyai.conv.DymStudy;
|
||||
import org.dromara.easyai.i.ActiveFunction;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @time 2025/8/7 08:55
|
||||
* @des QLearning 基础运算单元
|
||||
*/
|
||||
public class QBlock {
|
||||
private final DymStudy dymStudy;
|
||||
private final ActiveFunction activeFunction;
|
||||
private Matrix powerMatrix;//权重矩阵
|
||||
private Matrix bMatrix;//偏移矩阵
|
||||
private Matrix bDymStudyRate;//偏移量动态学习率
|
||||
private Matrix powerDymStudyRate;//权重动态学习率
|
||||
private MatrixOperation matrixOperation = new MatrixOperation();
|
||||
private Matrix inputMatrix;
|
||||
private Matrix outputMatrix;
|
||||
private QBlock sonBlock;//向前模块
|
||||
private QBlock outBlock;//输出模块
|
||||
private QBlock fatherBlock;//向后模块
|
||||
private int deep;//深度
|
||||
|
||||
public QBlock(DymStudy dymStudy, int inputSize, int outputSize, ActiveFunction activeFunction, int deep) throws Exception {
|
||||
Random random = new Random();
|
||||
this.deep = deep;
|
||||
this.dymStudy = dymStudy;
|
||||
this.activeFunction = activeFunction;
|
||||
this.powerMatrix = new Matrix(inputSize, outputSize);
|
||||
this.powerDymStudyRate = new Matrix(inputSize, outputSize);
|
||||
this.bMatrix = new Matrix(1, outputSize);
|
||||
this.bDymStudyRate = new Matrix(1, outputSize);
|
||||
initMatrix(powerMatrix, random);
|
||||
initMatrix(bMatrix, random);
|
||||
}
|
||||
|
||||
private Matrix calculation(Matrix featureMatrix, boolean study, Map<Integer, Float> E, OutBack outBack, Matrix wordMatrix) throws Exception {
|
||||
Matrix result = matrixOperation.mulMatrix(featureMatrix, powerMatrix);
|
||||
if (study) {
|
||||
inputMatrix = result;
|
||||
}
|
||||
Matrix outMatrix = matrixOperation.add(result, bMatrix);
|
||||
int x = outMatrix.getX();
|
||||
int y = outMatrix.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
float value = activeFunction.function(outMatrix.getNumber(i, j));
|
||||
outMatrix.setNub(i, j, value);
|
||||
}
|
||||
}
|
||||
if (study) {
|
||||
outputMatrix = outMatrix;
|
||||
}
|
||||
Matrix res;
|
||||
if (deep < wordMatrix.getX()) {
|
||||
Matrix word = wordMatrix.getRow(deep);
|
||||
res = matrixOperation.add(outMatrix, word);
|
||||
} else {
|
||||
res = outMatrix;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private void initMatrix(Matrix matrix, Random random) throws Exception {//初始化矩阵
|
||||
int x = matrix.getX();
|
||||
int y = matrix.getY();
|
||||
float sh = (float) Math.sqrt(x);
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
matrix.setNub(i, j, random.nextFloat() / sh);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void setSonBlock(QBlock sonBlock) {
|
||||
this.sonBlock = sonBlock;
|
||||
}
|
||||
|
||||
public void setOutBlock(QBlock outBlock) {
|
||||
this.outBlock = outBlock;
|
||||
}
|
||||
|
||||
public void setFatherBlock(QBlock fatherBlock) {
|
||||
this.fatherBlock = fatherBlock;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user