TransFormer增加可选择输出全分类概率参数

This commit is contained in:
lidapeng 2024-10-14 16:13:19 +08:00
parent 80d9f1a598
commit b6f0567ea7
13 changed files with 58 additions and 45 deletions

1
.gitignore vendored
View File

@ -9,6 +9,7 @@ __pycache__
*.so
.DS_Store
.idea/
pom.xml
# Distribution / packaging
.Python
build/

View File

@ -14,6 +14,15 @@ public class TfConfig {
private String splitWord;//词向量默认隔断符无隔断则会逐字隔断
private boolean selfTimeCode = true;//使用自增时间序列位置编码
private int coreNumber = 1;//是否使用多核并行计算进行提速
private boolean outAllPro = false;//是否输出全概率注意若输出全概率只能用来分类概率,否则将消耗大量内存
public boolean isOutAllPro() {
return outAllPro;
}
public void setOutAllPro(boolean outAllPro) {
this.outAllPro = outAllPro;
}
public int getCoreNumber() {
return coreNumber;

View File

@ -77,7 +77,7 @@ public class TalkToTalk extends MatrixOperation {
}
}
index++;
sensoryNerve.postMessage(eventID, qcMatrix, allFeatures, false, null, wordBack);
sensoryNerve.postMessage(eventID, qcMatrix, allFeatures, false, null, wordBack, false);
id = wordBack.getId();
if (id > 1) {
String word = wordEmbedding.getWord(id - 2);
@ -146,7 +146,7 @@ public class TalkToTalk extends MatrixOperation {
Matrix qMatrix = wordEmbedding.getEmbedding(question, 1, false).getFeatureMatrix();
AnswerE answerE = getSentenceMatrix(answer);
Matrix myAnswer = insertStart(answerE.answerMatrix, transFormerManager.getStartMatrix(qMatrix));//第一行补开始符
sensoryNerve.postMessage(1, qMatrix, myAnswer, true, answerE.answerList, null);
sensoryNerve.postMessage(1, qMatrix, myAnswer, true, answerE.answerList, null, false);
}
}
return transFormerManager.getModel();

View File

@ -110,13 +110,13 @@ public class CodecBlock {
}
public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack,
List<Integer> E, Matrix encoderFeature) throws Exception {//参数正向出口
List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//参数正向出口
if (beforeEncoderBlock != null) {
beforeEncoderBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, encoderFeature);
beforeEncoderBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, encoderFeature, outAllPro);
} else if (encoder) {//编码器走到末尾 保存输出矩阵
outMatrixMap.put(eventID, out);
} else {//解码器走到头了 输出线性分类层
lineBlock.sendParameter(eventID, out, isStudy, outBack, E);
lineBlock.sendParameter(eventID, out, isStudy, outBack, E, outAllPro);
}
}
@ -144,8 +144,8 @@ public class CodecBlock {
//Encoder 参数正向入口
public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E
, Matrix encoderFeature) throws Exception {
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, encoderFeature);
, Matrix encoderFeature, boolean outAllPro) throws Exception {
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, encoderFeature, outAllPro);
}
private void initLine(int featureDimension, double studyPoint, int regularModel, double regular) throws Exception {

View File

@ -48,14 +48,14 @@ public class FirstDecoderBlock {//解码器模块
lastEncoderBlock.encoderBackStart(eventID);
}
public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
Matrix c = lastEncoderBlock.getOutMatrix(eventID);
lastEncoderBlock.removeOutMatrix(eventID);
codecBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, c);
codecBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, c, outAllPro);
}
//Decoder 参数正向入口
public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, null);
public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, null, outAllPro);
}
}

View File

@ -76,9 +76,9 @@ public class LineBlock {//线性层模块
}
}
public void sendParameter(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
public void sendParameter(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
for (HiddenNerve hiddenNerve : hiddenNerveList) {
hiddenNerve.postMessage(eventID, feature, isStudy, outBack, E);
hiddenNerve.postMessage(eventID, feature, isStudy, outBack, E, outAllPro);
}
}

View File

@ -28,24 +28,24 @@ public class HiddenNerve extends Nerve {
}
public void receive(Matrix feature, long eventId, boolean isStudy, OutBack outBack,
List<Integer> E, Matrix encoderFeature) throws Exception {//接收上一个残差层传过来得参数
List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//接收上一个残差层传过来得参数
Matrix out = opMatrix(feature, isStudy);
sendMessage(eventId, out, isStudy, feature, outBack, E, encoderFeature);
sendMessage(eventId, out, isStudy, feature, outBack, E, encoderFeature, outAllPro);
}
@Override
protected void input(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack,
List<Integer> E, Matrix encoderFeature) throws Exception {//第二层收到参数
List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//第二层收到参数
boolean allReady = insertMatrixParameter(eventId, parameter);
if (allReady) {//参数齐了开始计算
Matrix out = opMatrix(reMatrixFeatures.get(eventId), isStudy);
reMatrixFeatures.remove(eventId);
beforeLayNorm.addNormFromNerve(eventId, isStudy, out, allFeature, outBack, E, encoderFeature);
beforeLayNorm.addNormFromNerve(eventId, isStudy, out, allFeature, outBack, E, encoderFeature, outAllPro);
}
}
public void postMessage(long eventId, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
public void postMessage(long eventId, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
Matrix out = opMatrix(feature, isStudy);
sendOutMessage(eventId, out, isStudy, outBack, E);
sendOutMessage(eventId, out, isStudy, outBack, E, outAllPro);
}
}

View File

@ -84,10 +84,10 @@ public abstract class Nerve {
}
protected void sendMessage(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack,
List<Integer> E, Matrix encoderFeature) throws Exception {
List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
if (!son.isEmpty()) {
for (Nerve nerve : son) {
nerve.input(eventId, parameter, isStudy, allFeature, outBack, E, encoderFeature);
nerve.input(eventId, parameter, isStudy, allFeature, outBack, E, encoderFeature, outAllPro);
}
}
@ -110,18 +110,18 @@ public abstract class Nerve {
}
protected void input(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack,
List<Integer> E, Matrix encoderFeature) throws Exception {//输入参数
List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//输入参数
}
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
}
protected void sendOutMessage(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
protected void sendOutMessage(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
if (!son.isEmpty()) {
for (Nerve nerve : son) {
nerve.toOut(eventId, parameter, isStudy, outBack, E);
nerve.toOut(eventId, parameter, isStudy, outBack, E, outAllPro);
}
}
}

View File

@ -29,12 +29,12 @@ public class OutNerve extends Nerve {
@Override
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
boolean allReady = insertMatrixParameter(eventId, parameter);
if (allReady) {
Matrix out = opMatrix(reMatrixFeatures.get(eventId), isStudy);
reMatrixFeatures.remove(eventId);
softMax.toOut(eventId, out, isStudy, outBack, E);
softMax.toOut(eventId, out, isStudy, outBack, E, outAllPro);
}
}
}

View File

@ -33,9 +33,9 @@ public class SensoryNerve {
* @param outBack 回调结果
*/
public void postMessage(long eventId, Matrix encoderParameter, Matrix decoderParameter, boolean isStudy, List<Integer> E
, OutBack outBack) throws Exception {//感知神经元输入
firstEncoderBlock.sendInputMatrix(eventId, encoderParameter, isStudy, outBack, E, null);
firstDecoderBlock.sendInputMatrix(eventId, decoderParameter, isStudy, outBack, E);
, OutBack outBack, boolean outAllPro) throws Exception {//感知神经元输入
firstEncoderBlock.sendInputMatrix(eventId, encoderParameter, isStudy, outBack, E, null, outAllPro);
firstDecoderBlock.sendInputMatrix(eventId, decoderParameter, isStudy, outBack, E, outAllPro);
}
}

View File

@ -22,7 +22,7 @@ public class SoftMax extends Nerve {
}
@Override
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E) throws Exception {
protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
boolean allReady = insertMatrixParameter(eventId, parameter);
if (allReady) {
Matrix feature = reMatrixFeatures.get(eventId);//特征
@ -35,7 +35,7 @@ public class SoftMax extends Nerve {
Matrix allError = null;
for (int i = 0; i < x; i++) {
Matrix row = feature.getRow(i);
Mes mes = softMax(true, row);//输出值
Mes mes = softMax(true, row, false);//输出值
int key = E.get(i);
if (isShowLog) {
System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
@ -54,8 +54,11 @@ public class SoftMax extends Nerve {
}
} else {
if (outBack != null) {
Mes mes = softMax(false, feature.getRow(x - 1));//输出值
Mes mes = softMax(false, feature.getRow(x - 1), outAllPro);//输出值
outBack.getBack(mes.poi, mes.typeID, eventId);
if (outAllPro) {
outBack.getSoftMaxBack(eventId, mes.softMax);
}
} else {
throw new Exception("not find outBack");
}
@ -80,7 +83,7 @@ public class SoftMax extends Nerve {
return matrix;
}
private Mes softMax(boolean isStudy, Matrix matrix) throws Exception {//计算当前输出结果
private Mes softMax(boolean isStudy, Matrix matrix, boolean outAllPro) throws Exception {//计算当前输出结果
double sigma = 0;
int id = 0;
double poi = 0;
@ -94,7 +97,7 @@ public class SoftMax extends Nerve {
for (int i = 0; i < size; i++) {
double eSelf = Math.exp(matrix.getNumber(0, i));
double value = eSelf / sigma;
if (isStudy) {
if (isStudy || outAllPro) {
softMax.add(value);
}
if (value > poi) {

View File

@ -152,22 +152,22 @@ public class LayNorm {//残差与归一化
}
public void addNorm(Matrix feature, Matrix outMatrix, long eventID, boolean isStudy
, OutBack outBack, List<Integer> E, Matrix encoderFeature) throws Exception {//残差及归一化
, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//残差及归一化
Matrix myMatrix = matrixOperation.add(feature, outMatrix);//残差相加
Matrix out = layNorm(myMatrix, isStudy);
if (type == 1) {
if (myEncoderBlock != null) {
sendHiddenParameter(out, eventID, isStudy, outBack, E, encoderFeature);//发送线性第一层
sendHiddenParameter(out, eventID, isStudy, outBack, E, encoderFeature, outAllPro);//发送线性第一层
} else if (firstDecoderBlock != null) {//解码器第一层//输出
firstDecoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E);
firstDecoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, outAllPro);
}
} else {//输出矩阵
myEncoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, encoderFeature);
myEncoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, encoderFeature, outAllPro);
}
}
public void addNormFromNerve(long eventID, boolean isStudy, Matrix parameter, Matrix allFeature,
OutBack outBack, List<Integer> E, Matrix encoderFeature) throws Exception {
OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
Matrix matrixFeature;
if (reMatrixMap.containsKey(eventID)) {
Matrix myFeature = reMatrixMap.get(eventID);
@ -178,14 +178,14 @@ public class LayNorm {//残差与归一化
reMatrixMap.put(eventID, matrixFeature);
if (matrixFeature.getY() == featureDimension) {//执行残差
reMatrixMap.remove(eventID);
addNorm(matrixFeature, allFeature, eventID, isStudy, outBack, E, encoderFeature);
addNorm(matrixFeature, allFeature, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
}
}
private void sendHiddenParameter(Matrix feature, long eventId, boolean isStudy
, OutBack outBack, List<Integer> E, Matrix encoderFeature) throws Exception {//hiddenNerves
, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//hiddenNerves
for (HiddenNerve hiddenNerve : hiddenNerves) {
hiddenNerve.receive(feature, eventId, isStudy, outBack, E, encoderFeature);
hiddenNerve.receive(feature, eventId, isStudy, outBack, E, encoderFeature, outAllPro);
}
}

View File

@ -204,7 +204,7 @@ public class MultiSelfAttention {//多头自注意力层
}
public void sendMatrixMessage(long eventID, Matrix feature, boolean isStudy
, OutBack outBack, List<Integer> E, Matrix encoderFeature) throws Exception {//从输入神经元
, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {//从输入神经元
if (depth == 1) {//如果是第一层则添加时间序列参数
if (selfTimeCode) {
addTimeCodeBySelf(feature);
@ -218,7 +218,7 @@ public class MultiSelfAttention {//多头自注意力层
eventBodies.add(eventBody);
}
Matrix matrix = countMultiSelfAttention(eventBodies, isStudy);//多头输出
layNorm.addNorm(feature, matrix, eventID, isStudy, outBack, E, encoderFeature);//进第一个残差层
layNorm.addNorm(feature, matrix, eventID, isStudy, outBack, E, encoderFeature, outAllPro);//进第一个残差层
}