Add File
This commit is contained in:
@@ -0,0 +1,206 @@
|
||||
package org.dromara.easyai.transFormer.seflAttention;
|
||||
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.transFormer.model.QKVModel;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public class SelfAttention {//自注意力层
|
||||
private final Map<Long, MyFeature> featureMatrix = new HashMap<>();//特征矩阵
|
||||
private Matrix powerQ;//q权重矩阵
|
||||
private Matrix powerK;//k权重矩阵
|
||||
private Matrix powerV;//v权重矩阵
|
||||
private final int wordVectorDimension;//特征矩阵维度
|
||||
private final int depth;//深度
|
||||
private final float studyPoint;//学习率
|
||||
private final int selfID;
|
||||
private final boolean encoder;//是否为编码器模块
|
||||
private final MatrixOperation matrixOperation;
|
||||
|
||||
public int getSelfID() {
|
||||
return selfID;
|
||||
}
|
||||
|
||||
public SelfAttention(float studyPoint, int depth, int wordVectorDimension, int selfID, boolean encoder
|
||||
, int coreNumber) throws Exception {
|
||||
matrixOperation = new MatrixOperation(coreNumber);
|
||||
this.studyPoint = studyPoint;
|
||||
this.depth = depth;
|
||||
this.encoder = encoder;
|
||||
this.wordVectorDimension = wordVectorDimension;
|
||||
this.selfID = selfID;
|
||||
powerQ = initPowerMatrix(wordVectorDimension);
|
||||
powerK = initPowerMatrix(wordVectorDimension);
|
||||
powerV = initPowerMatrix(wordVectorDimension);
|
||||
}
|
||||
|
||||
public void insertModel(QKVModel qkvModel) throws Exception {
|
||||
insertPower(qkvModel.getQ(), powerQ);
|
||||
insertPower(qkvModel.getK(), powerK);
|
||||
insertPower(qkvModel.getV(), powerV);
|
||||
}
|
||||
|
||||
private void insertPower(float[][] modelPower, Matrix power) throws Exception {
|
||||
for (int i = 0; i < power.getX(); i++) {
|
||||
for (int j = 0; j < power.getY(); j++) {
|
||||
power.setNub(i, j, modelPower[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public QKVModel getModel() throws Exception {
|
||||
QKVModel qkvModel = new QKVModel();
|
||||
qkvModel.setQ(powerQ.getMatrix());
|
||||
qkvModel.setK(powerK.getMatrix());
|
||||
qkvModel.setV(powerV.getMatrix());
|
||||
qkvModel.setSelfID(selfID);
|
||||
return qkvModel;
|
||||
}
|
||||
|
||||
|
||||
public AttentionError backError(Matrix feature, long eventID) throws Exception {//返回误差
|
||||
Matrix myError = matrixOperation.mathMulBySelf(feature, studyPoint);
|
||||
MyFeature featureBody = this.featureMatrix.get(eventID);
|
||||
Matrix q = featureBody.q;
|
||||
Matrix kt = featureBody.kt;
|
||||
Matrix v = featureBody.v;
|
||||
Matrix qkt = featureBody.qkt;
|
||||
Matrix errorV = matrixOperation.matrixMulPd(myError, qkt, v, false);//先求V的偏导
|
||||
Matrix subQktMax = matrixOperation.matrixMulPd(feature, qkt, v, true);
|
||||
Matrix grMatrix = matrixOperation.matrixSoftMaxPd(qkt, subQktMax, wordVectorDimension);//对softMax做误差求导
|
||||
if (depth == 1 && !encoder) {
|
||||
backMask(grMatrix);
|
||||
}
|
||||
Matrix errorKt = matrixOperation.matrixMulPd(grMatrix, q, kt, false);
|
||||
Matrix errorQ = matrixOperation.matrixMulPd(grMatrix, q, kt, true);
|
||||
Matrix errorK = matrixOperation.transPosition(errorKt);
|
||||
ErrorFeature QPower = updateError(errorQ, featureBody.allFeature, powerQ);
|
||||
Matrix leftMatrix = featureBody.allFeature;
|
||||
if (!encoder && depth > 1) {//大于一层的解码器
|
||||
leftMatrix = featureBody.encoderFeature;
|
||||
}
|
||||
ErrorFeature KPower = updateError(errorK, leftMatrix, powerK);
|
||||
ErrorFeature VPower = updateError(errorV, leftMatrix, powerV);
|
||||
powerQ = QPower.powerMatrix;//更新权重
|
||||
powerK = KPower.powerMatrix;
|
||||
powerV = VPower.powerMatrix;
|
||||
AttentionError attentionError = new AttentionError();
|
||||
Matrix nextFeatureError;//下一层误差
|
||||
Matrix lastEncoderError = null;//编码器最后一层误差
|
||||
if (!encoder && depth > 1) {//大于一层的解码器
|
||||
nextFeatureError = QPower.errorFeatureMatrix;
|
||||
lastEncoderError = matrixOperation.add(KPower.errorFeatureMatrix, VPower.errorFeatureMatrix);
|
||||
} else {
|
||||
nextFeatureError = matrixOperation.addThreeMatrix(QPower.errorFeatureMatrix, KPower.errorFeatureMatrix
|
||||
, VPower.errorFeatureMatrix);
|
||||
}
|
||||
attentionError.setNextFeatureError(nextFeatureError);
|
||||
attentionError.setLastEncoderError(lastEncoderError);
|
||||
this.featureMatrix.remove(eventID);//清除老数据
|
||||
return attentionError;
|
||||
}
|
||||
|
||||
private ErrorFeature updateError(Matrix errorMatrix, Matrix feature, Matrix powerMatrix) throws Exception {//调整误差
|
||||
Matrix errorPower = matrixOperation.matrixMulPd(errorMatrix, feature, powerMatrix, false);
|
||||
Matrix featureError = matrixOperation.matrixMulPd(errorMatrix, feature, powerMatrix, true);
|
||||
Matrix nextPowerMatrix = matrixOperation.add(powerMatrix, errorPower);
|
||||
ErrorFeature errorFeature = new ErrorFeature();
|
||||
errorFeature.errorFeatureMatrix = featureError;
|
||||
errorFeature.powerMatrix = nextPowerMatrix;
|
||||
return errorFeature;
|
||||
}
|
||||
|
||||
private void backMask(Matrix matrix) throws Exception {
|
||||
int x = matrix.getX();
|
||||
int y = matrix.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = i + 1; j < y; j++) {
|
||||
matrix.setNub(i, j, 0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void mask(Matrix matrix) throws Exception {
|
||||
int x = matrix.getX();
|
||||
int y = matrix.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = i + 1; j < y; j++) {
|
||||
matrix.setNub(i, j, -1000f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Matrix countSelfAttention(long eventID, boolean isStudy) throws Exception {//进行注意力层正向计算
|
||||
MyFeature featureBody = this.featureMatrix.get(eventID);
|
||||
Matrix myFeature = featureBody.allFeature;
|
||||
Matrix kvFeature;
|
||||
if (!encoder && depth > 1) {//大于1层的解码模块 使用 编码器输出的矩阵来生成kv矩阵
|
||||
kvFeature = featureBody.encoderFeature;
|
||||
//System.out.println(kvFeature);
|
||||
} else {
|
||||
kvFeature = featureBody.allFeature;
|
||||
}
|
||||
Matrix q = matrixOperation.mulMatrix(myFeature, powerQ);
|
||||
Matrix k = matrixOperation.mulMatrix(kvFeature, powerK);
|
||||
Matrix v = matrixOperation.mulMatrix(kvFeature, powerV);
|
||||
Matrix kt = matrixOperation.transPosition(k);//k转置
|
||||
Matrix qkt = matrixOperation.mulMatrix(q, kt);
|
||||
matrixOperation.mathDiv(qkt, (float) Math.sqrt(wordVectorDimension));
|
||||
//做蒙版
|
||||
if (depth == 1 && !encoder) {//第一层解码器 需要先做蒙版操作
|
||||
mask(qkt);
|
||||
}
|
||||
matrixOperation.softMax(qkt);
|
||||
Matrix result = matrixOperation.mulMatrix(qkt, v);
|
||||
if (!isStudy) {
|
||||
this.featureMatrix.remove(eventID);
|
||||
} else {
|
||||
featureBody.q = q;
|
||||
featureBody.kt = kt;
|
||||
featureBody.v = v;
|
||||
featureBody.qkt = qkt;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public EventBody sendMatrixFeature(long eventID, boolean isStudy, Matrix feature, Matrix encoderFeature) throws Exception {
|
||||
EventBody eventBody = new EventBody();
|
||||
eventBody.setEventID(eventID);
|
||||
eventBody.setSelfID(selfID);
|
||||
MyFeature myFeature = new MyFeature();
|
||||
myFeature.allFeature = feature;
|
||||
myFeature.encoderFeature = encoderFeature;
|
||||
featureMatrix.put(eventID, myFeature);
|
||||
eventBody.setFeatureMatrix(countSelfAttention(eventID, isStudy));
|
||||
return eventBody;
|
||||
}
|
||||
|
||||
|
||||
private Matrix initPowerMatrix(int wordVectorDimension) throws Exception {//初始化权重矩阵
|
||||
Random random = new Random();
|
||||
Matrix power = new Matrix(wordVectorDimension, wordVectorDimension);
|
||||
for (int i = 0; i < wordVectorDimension; i++) {
|
||||
for (int j = 0; j < wordVectorDimension; j++) {
|
||||
power.setNub(i, j, random.nextFloat() / wordVectorDimension);
|
||||
}
|
||||
}
|
||||
return power;
|
||||
}
|
||||
|
||||
static class MyFeature {
|
||||
Matrix allFeature;//后方传送特征
|
||||
Matrix encoderFeature;//编码器传送特征
|
||||
Matrix q;
|
||||
Matrix kt;
|
||||
Matrix v;
|
||||
Matrix qkt;
|
||||
}
|
||||
|
||||
static class ErrorFeature {
|
||||
Matrix errorFeatureMatrix;
|
||||
Matrix powerMatrix;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user