This commit is contained in:
2025-09-04 14:09:17 +08:00
parent 10e6b209aa
commit 3f89e062eb

View File

@@ -0,0 +1,62 @@
package org.dromara.easyai.transFormer;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.transFormer.model.FirstDecoderModel;
import org.dromara.easyai.transFormer.seflAttention.LayNorm;
import org.dromara.easyai.transFormer.seflAttention.MultiSelfAttention;
import java.util.List;
public class FirstDecoderBlock {//解码器模块
private final MultiSelfAttention multiSelfAttention;
private final LayNorm attentionLayNorm;
//////////////////
private final CodecBlock codecBlock;//前方的解码层
private CodecBlock lastEncoderBlock;//最后一层编码器
public void setLastEncoderBlock(CodecBlock lastEncoderBlock) {
this.lastEncoderBlock = lastEncoderBlock;
}
public FirstDecoderBlock(int multiNumber, int featureDimension, float studyPoint, CodecBlock codecBlock, int coreNumber,
TransWordVector transWordVector) throws Exception {//进行初始化
//注意力层残差归一化
attentionLayNorm = new LayNorm(1, featureDimension, null, this, studyPoint, coreNumber, false
, 1);
multiSelfAttention = new MultiSelfAttention(multiNumber, studyPoint, 1, featureDimension, false, null,
coreNumber, transWordVector);
multiSelfAttention.setLayNorm(attentionLayNorm);
attentionLayNorm.setMultiSelfAttention(multiSelfAttention);
this.codecBlock = codecBlock;
}
public FirstDecoderModel getModel() throws Exception {
FirstDecoderModel firstDecoderModel = new FirstDecoderModel();
firstDecoderModel.setMultiSelfAttentionModel(multiSelfAttention.getModel());
firstDecoderModel.setAttentionLayNormModel(attentionLayNorm.getModel());
return firstDecoderModel;
}
public void insertModel(FirstDecoderModel firstDecoderModel) throws Exception {
multiSelfAttention.insertModel(firstDecoderModel.getMultiSelfAttentionModel());
attentionLayNorm.insertModel(firstDecoderModel.getAttentionLayNormModel());
}
public void backError(long eventID, Matrix error) throws Exception {
attentionLayNorm.backErrorFromLine(error, eventID);
lastEncoderBlock.encoderBackStart(eventID);
}
public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
Matrix c = lastEncoderBlock.getOutMatrix(eventID);
lastEncoderBlock.removeOutMatrix(eventID);
codecBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, c, outAllPro);
}
//Decoder 参数正向入口
public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, null, outAllPro);
}
}