Add File
This commit is contained in:
@@ -0,0 +1,62 @@
|
|||||||
|
package org.dromara.easyai.transFormer;
|
||||||
|
|
||||||
|
import org.dromara.easyai.i.OutBack;
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.transFormer.model.FirstDecoderModel;
|
||||||
|
import org.dromara.easyai.transFormer.seflAttention.LayNorm;
|
||||||
|
import org.dromara.easyai.transFormer.seflAttention.MultiSelfAttention;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
|
public class FirstDecoderBlock {//解码器模块
|
||||||
|
private final MultiSelfAttention multiSelfAttention;
|
||||||
|
private final LayNorm attentionLayNorm;
|
||||||
|
//////////////////
|
||||||
|
private final CodecBlock codecBlock;//前方的解码层
|
||||||
|
private CodecBlock lastEncoderBlock;//最后一层编码器
|
||||||
|
|
||||||
|
public void setLastEncoderBlock(CodecBlock lastEncoderBlock) {
|
||||||
|
this.lastEncoderBlock = lastEncoderBlock;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FirstDecoderBlock(int multiNumber, int featureDimension, float studyPoint, CodecBlock codecBlock, int coreNumber,
|
||||||
|
TransWordVector transWordVector) throws Exception {//进行初始化
|
||||||
|
//注意力层残差归一化
|
||||||
|
attentionLayNorm = new LayNorm(1, featureDimension, null, this, studyPoint, coreNumber, false
|
||||||
|
, 1);
|
||||||
|
multiSelfAttention = new MultiSelfAttention(multiNumber, studyPoint, 1, featureDimension, false, null,
|
||||||
|
coreNumber, transWordVector);
|
||||||
|
multiSelfAttention.setLayNorm(attentionLayNorm);
|
||||||
|
attentionLayNorm.setMultiSelfAttention(multiSelfAttention);
|
||||||
|
this.codecBlock = codecBlock;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FirstDecoderModel getModel() throws Exception {
|
||||||
|
FirstDecoderModel firstDecoderModel = new FirstDecoderModel();
|
||||||
|
firstDecoderModel.setMultiSelfAttentionModel(multiSelfAttention.getModel());
|
||||||
|
firstDecoderModel.setAttentionLayNormModel(attentionLayNorm.getModel());
|
||||||
|
return firstDecoderModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertModel(FirstDecoderModel firstDecoderModel) throws Exception {
|
||||||
|
multiSelfAttention.insertModel(firstDecoderModel.getMultiSelfAttentionModel());
|
||||||
|
attentionLayNorm.insertModel(firstDecoderModel.getAttentionLayNormModel());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void backError(long eventID, Matrix error) throws Exception {
|
||||||
|
attentionLayNorm.backErrorFromLine(error, eventID);
|
||||||
|
lastEncoderBlock.encoderBackStart(eventID);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
|
||||||
|
Matrix c = lastEncoderBlock.getOutMatrix(eventID);
|
||||||
|
lastEncoderBlock.removeOutMatrix(eventID);
|
||||||
|
codecBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, c, outAllPro);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Decoder 参数正向入口
|
||||||
|
public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
|
||||||
|
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, null, outAllPro);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user