From f0beccab0b20d926a7e93c3bd7d7a46e5278da68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=A4=A7=E9=B9=8F?= <794757862@qq.com> Date: Mon, 15 Jul 2024 10:17:58 +0800 Subject: [PATCH] =?UTF-8?q?transFormer=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E4=B8=8E=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/org/wlld/transFormer/CodecBlock.java | 35 ++++++++++- .../wlld/transFormer/FirstDecoderBlock.java | 13 +++++ .../java/org/wlld/transFormer/LineBlock.java | 29 +++++++++- .../wlld/transFormer/TransFormerManager.java | 29 +++++++++- .../transFormer/model/CodecBlockModel.java | 52 +++++++++++++++++ .../transFormer/model/FirstDecoderModel.java | 22 +++++++ .../wlld/transFormer/model/LayNormModel.java | 23 ++++++++ .../transFormer/model/LineBlockModel.java | 24 ++++++++ .../MultiSelfAttentionModel.java | 2 +- .../{seflAttention => model}/QKVModel.java | 2 +- .../transFormer/model/TransFormerModel.java | 42 ++++++++++++++ .../wlld/transFormer/nerve/HiddenNerve.java | 6 +- .../org/wlld/transFormer/nerve/Nerve.java | 58 +++++-------------- .../org/wlld/transFormer/nerve/OutNerve.java | 2 +- .../org/wlld/transFormer/nerve/SoftMax.java | 5 +- .../transFormer/seflAttention/LayNorm.java | 27 ++++++++- .../seflAttention/MultiSelfAttention.java | 4 +- .../seflAttention/SelfAttention.java | 1 + 18 files changed, 312 insertions(+), 64 deletions(-) create mode 100644 src/main/java/org/wlld/transFormer/model/CodecBlockModel.java create mode 100644 src/main/java/org/wlld/transFormer/model/FirstDecoderModel.java create mode 100644 src/main/java/org/wlld/transFormer/model/LayNormModel.java create mode 100644 src/main/java/org/wlld/transFormer/model/LineBlockModel.java rename src/main/java/org/wlld/transFormer/{seflAttention => model}/MultiSelfAttentionModel.java (93%) rename src/main/java/org/wlld/transFormer/{seflAttention => model}/QKVModel.java (93%) create mode 100644 src/main/java/org/wlld/transFormer/model/TransFormerModel.java diff --git a/src/main/java/org/wlld/transFormer/CodecBlock.java b/src/main/java/org/wlld/transFormer/CodecBlock.java index dbcb7e6..27d40a9 100644 --- a/src/main/java/org/wlld/transFormer/CodecBlock.java +++ b/src/main/java/org/wlld/transFormer/CodecBlock.java @@ -4,6 +4,7 @@ import org.wlld.function.ReLu; import org.wlld.i.OutBack; import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; +import org.wlld.transFormer.model.CodecBlockModel; import org.wlld.transFormer.nerve.HiddenNerve; import org.wlld.transFormer.nerve.Nerve; import org.wlld.transFormer.seflAttention.LayNorm; @@ -24,11 +25,39 @@ public class CodecBlock { private CodecBlock afterEncoderBlock;//后编码模块 private CodecBlock beforeEncoderBlock;//前编码模块 private CodecBlock lastEncoderBlock;//最后一层编码器 - private Map outMatrixMap = new HashMap<>(); + private final Map outMatrixMap = new HashMap<>(); private final boolean encoder;//是否为编码器 private LineBlock lineBlock;//解码器最后的线性分类器 private FirstDecoderBlock firstDecoderBlock;//解码器第一层 + public CodecBlockModel getModel() { + List firstNerveModel = new ArrayList<>(); + List secondNerveModel = new ArrayList<>(); + for (int i = 0; i < fistHiddenNerves.size(); i++) { + firstNerveModel.add(fistHiddenNerves.get(i).getModel()); + secondNerveModel.add(secondHiddenNerves.get(i).getModel()); + } + CodecBlockModel codecBlockModel = new CodecBlockModel(); + codecBlockModel.setMultiSelfAttentionModel(multiSelfAttention.getModel()); + codecBlockModel.setAttentionLayNormModel(attentionLayNorm.getModel()); + codecBlockModel.setFistNervesModel(firstNerveModel); + codecBlockModel.setSecondNervesModel(secondNerveModel); + codecBlockModel.setLineLayNormModel(lineLayNorm.getModel()); + return codecBlockModel; + } + + public void insertModel(CodecBlockModel codecBlockModel) throws Exception { + multiSelfAttention.insertModel(codecBlockModel.getMultiSelfAttentionModel()); + attentionLayNorm.insertModel(codecBlockModel.getAttentionLayNormModel()); + List firstNerveModel = codecBlockModel.getFistNervesModel(); + List secondNerveModel = codecBlockModel.getSecondNervesModel(); + for (int i = 0; i < fistHiddenNerves.size(); i++) { + fistHiddenNerves.get(i).insertModel(firstNerveModel.get(i)); + secondHiddenNerves.get(i).insertModel(secondNerveModel.get(i)); + } + lineLayNorm.insertModel(codecBlockModel.getLineLayNormModel()); + } + public void setFirstDecoderBlock(FirstDecoderBlock firstDecoderBlock) { this.firstDecoderBlock = firstDecoderBlock; } @@ -117,14 +146,14 @@ public class CodecBlock { List secondNerves = new ArrayList<>(); for (int i = 0; i < featureDimension; i++) { HiddenNerve hiddenNerve1 = new HiddenNerve(i + 1, 1, studyPoint, new ReLu(), featureDimension, - featureDimension, true, null); + featureDimension, null); fistHiddenNerves.add(hiddenNerve1); hiddenNerve1.setAfterLayNorm(attentionLayNorm); firstNerves.add(hiddenNerve1); } for (int i = 0; i < featureDimension; i++) { HiddenNerve hiddenNerve2 = new HiddenNerve(i + 1, 2, studyPoint, null, - featureDimension, 1, true, null); + featureDimension, 1, null); hiddenNerve2.setBeforeLayNorm(lineLayNorm); secondHiddenNerves.add(hiddenNerve2); secondNerves.add(hiddenNerve2); diff --git a/src/main/java/org/wlld/transFormer/FirstDecoderBlock.java b/src/main/java/org/wlld/transFormer/FirstDecoderBlock.java index fd1be77..0c717ce 100644 --- a/src/main/java/org/wlld/transFormer/FirstDecoderBlock.java +++ b/src/main/java/org/wlld/transFormer/FirstDecoderBlock.java @@ -2,6 +2,7 @@ package org.wlld.transFormer; import org.wlld.i.OutBack; import org.wlld.matrixTools.Matrix; +import org.wlld.transFormer.model.FirstDecoderModel; import org.wlld.transFormer.seflAttention.LayNorm; import org.wlld.transFormer.seflAttention.MultiSelfAttention; @@ -29,6 +30,18 @@ public class FirstDecoderBlock {//解码器模块 this.codecBlock = codecBlock; } + public FirstDecoderModel getModel() { + FirstDecoderModel firstDecoderModel = new FirstDecoderModel(); + firstDecoderModel.setMultiSelfAttentionModel(multiSelfAttention.getModel()); + firstDecoderModel.setAttentionLayNormModel(attentionLayNorm.getModel()); + return firstDecoderModel; + } + + public void insertModel(FirstDecoderModel firstDecoderModel) throws Exception { + multiSelfAttention.insertModel(firstDecoderModel.getMultiSelfAttentionModel()); + attentionLayNorm.insertModel(firstDecoderModel.getAttentionLayNormModel()); + } + public void backError(long eventID, Matrix error) throws Exception { attentionLayNorm.backErrorFromLine(error, eventID); lastEncoderBlock.encoderBackStart(eventID); diff --git a/src/main/java/org/wlld/transFormer/LineBlock.java b/src/main/java/org/wlld/transFormer/LineBlock.java index 1867a9a..c152d24 100644 --- a/src/main/java/org/wlld/transFormer/LineBlock.java +++ b/src/main/java/org/wlld/transFormer/LineBlock.java @@ -4,6 +4,7 @@ import org.wlld.function.Tanh; import org.wlld.i.OutBack; import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; +import org.wlld.transFormer.model.LineBlockModel; import org.wlld.transFormer.nerve.HiddenNerve; import org.wlld.transFormer.nerve.Nerve; import org.wlld.transFormer.nerve.OutNerve; @@ -20,6 +21,32 @@ public class LineBlock {//线性层模块 private final int featureDimension; private int backNumber = 0;//误差返回次数 + public LineBlockModel getModel() { + LineBlockModel lineBlockModel = new LineBlockModel(); + List hiddenNerveModel = new ArrayList<>(); + List outNerveModel = new ArrayList<>(); + for (HiddenNerve hiddenNerve : hiddenNerveList) { + hiddenNerveModel.add(hiddenNerve.getModel()); + } + for (OutNerve outNerve : outNerveList) { + outNerveModel.add(outNerve.getModel()); + } + lineBlockModel.setHiddenNervesModel(hiddenNerveModel); + lineBlockModel.setOutNervesModel(outNerveModel); + return lineBlockModel; + } + + public void insertModel(LineBlockModel lineBlockModel) throws Exception { + List hiddenNerveModel = lineBlockModel.getHiddenNervesModel(); + List outNerveModel = lineBlockModel.getOutNervesModel(); + for (int i = 0; i < hiddenNerveList.size(); i++) { + hiddenNerveList.get(i).insertModel(hiddenNerveModel.get(i)); + } + for (int i = 0; i < outNerveList.size(); i++) { + outNerveList.get(i).insertModel(outNerveModel.get(i)); + } + } + public LineBlock(int typeNumber, int featureDimension, double studyPoint, CodecBlock lastCodecBlock, boolean showLog) throws Exception { this.featureDimension = featureDimension; this.lastCodecBlock = lastCodecBlock; @@ -28,7 +55,7 @@ public class LineBlock {//线性层模块 List hiddenNerves = new ArrayList<>(); for (int i = 0; i < featureDimension; i++) { HiddenNerve hiddenNerve = new HiddenNerve(i + 1, 1, studyPoint, new Tanh(), featureDimension, - typeNumber, false, this); + typeNumber, this); hiddenNerves.add(hiddenNerve); hiddenNerveList.add(hiddenNerve); } diff --git a/src/main/java/org/wlld/transFormer/TransFormerManager.java b/src/main/java/org/wlld/transFormer/TransFormerManager.java index f0bd90c..faa88e2 100644 --- a/src/main/java/org/wlld/transFormer/TransFormerManager.java +++ b/src/main/java/org/wlld/transFormer/TransFormerManager.java @@ -1,6 +1,7 @@ package org.wlld.transFormer; -import org.wlld.naturalLanguage.word.WordEmbedding; +import org.wlld.transFormer.model.CodecBlockModel; +import org.wlld.transFormer.model.TransFormerModel; import org.wlld.transFormer.nerve.SensoryNerve; import java.util.ArrayList; @@ -17,6 +18,32 @@ public class TransFormerManager { return sensoryNerve; } + public TransFormerModel getModel() { + TransFormerModel transFormerModel = new TransFormerModel(); + List encoderBlockModels = new ArrayList<>(); + List decoderBlockModels = new ArrayList<>(); + for (int i = 0; i < encoderBlocks.size(); i++) { + encoderBlockModels.add(encoderBlocks.get(i).getModel()); + decoderBlockModels.add(decoderBlocks.get(i).getModel()); + } + transFormerModel.setEncoderBlockModels(encoderBlockModels); + transFormerModel.setDecoderBlockModels(decoderBlockModels); + transFormerModel.setFirstDecoderBlockModel(firstDecoderBlock.getModel()); + transFormerModel.setLineBlockModel(lineBlock.getModel()); + return transFormerModel; + } + + public void insertModel(TransFormerModel transFormerModel) throws Exception { + List encoderBlockModels = transFormerModel.getEncoderBlockModels(); + List decoderBlockModels = transFormerModel.getDecoderBlockModels(); + for (int i = 0; i < encoderBlocks.size(); i++) { + encoderBlocks.get(i).insertModel(encoderBlockModels.get(i)); + decoderBlocks.get(i).insertModel(decoderBlockModels.get(i)); + } + firstDecoderBlock.insertModel(transFormerModel.getFirstDecoderBlockModel()); + lineBlock.insertModel(transFormerModel.getLineBlockModel()); + } + /** * 初始化神经元参数 * diff --git a/src/main/java/org/wlld/transFormer/model/CodecBlockModel.java b/src/main/java/org/wlld/transFormer/model/CodecBlockModel.java new file mode 100644 index 0000000..93cf0a2 --- /dev/null +++ b/src/main/java/org/wlld/transFormer/model/CodecBlockModel.java @@ -0,0 +1,52 @@ +package org.wlld.transFormer.model; + + +import java.util.List; + +public class CodecBlockModel { + private MultiSelfAttentionModel multiSelfAttentionModel;//注意力层model + private LayNormModel attentionLayNormModel;//残差1层model + private List fistNervesModel;//FNN层第一层model + private List secondNervesModel;//FNN层第二层model + private LayNormModel lineLayNormModel;//残差层最后2层model + + public MultiSelfAttentionModel getMultiSelfAttentionModel() { + return multiSelfAttentionModel; + } + + public void setMultiSelfAttentionModel(MultiSelfAttentionModel multiSelfAttentionModel) { + this.multiSelfAttentionModel = multiSelfAttentionModel; + } + + public LayNormModel getAttentionLayNormModel() { + return attentionLayNormModel; + } + + public void setAttentionLayNormModel(LayNormModel attentionLayNormModel) { + this.attentionLayNormModel = attentionLayNormModel; + } + + public List getFistNervesModel() { + return fistNervesModel; + } + + public void setFistNervesModel(List fistNervesModel) { + this.fistNervesModel = fistNervesModel; + } + + public List getSecondNervesModel() { + return secondNervesModel; + } + + public void setSecondNervesModel(List secondNervesModel) { + this.secondNervesModel = secondNervesModel; + } + + public LayNormModel getLineLayNormModel() { + return lineLayNormModel; + } + + public void setLineLayNormModel(LayNormModel lineLayNormModel) { + this.lineLayNormModel = lineLayNormModel; + } +} diff --git a/src/main/java/org/wlld/transFormer/model/FirstDecoderModel.java b/src/main/java/org/wlld/transFormer/model/FirstDecoderModel.java new file mode 100644 index 0000000..6eca046 --- /dev/null +++ b/src/main/java/org/wlld/transFormer/model/FirstDecoderModel.java @@ -0,0 +1,22 @@ +package org.wlld.transFormer.model; + +public class FirstDecoderModel { + private MultiSelfAttentionModel multiSelfAttentionModel;//注意力层model + private LayNormModel attentionLayNormModel;//残差1层model + + public MultiSelfAttentionModel getMultiSelfAttentionModel() { + return multiSelfAttentionModel; + } + + public void setMultiSelfAttentionModel(MultiSelfAttentionModel multiSelfAttentionModel) { + this.multiSelfAttentionModel = multiSelfAttentionModel; + } + + public LayNormModel getAttentionLayNormModel() { + return attentionLayNormModel; + } + + public void setAttentionLayNormModel(LayNormModel attentionLayNormModel) { + this.attentionLayNormModel = attentionLayNormModel; + } +} diff --git a/src/main/java/org/wlld/transFormer/model/LayNormModel.java b/src/main/java/org/wlld/transFormer/model/LayNormModel.java new file mode 100644 index 0000000..85f333b --- /dev/null +++ b/src/main/java/org/wlld/transFormer/model/LayNormModel.java @@ -0,0 +1,23 @@ +package org.wlld.transFormer.model; + + +public class LayNormModel { + private double[][] bTa;//模型需要保存 + private double[][] power;//模型需要保存 + + public double[][] getbTa() { + return bTa; + } + + public void setbTa(double[][] bTa) { + this.bTa = bTa; + } + + public double[][] getPower() { + return power; + } + + public void setPower(double[][] power) { + this.power = power; + } +} diff --git a/src/main/java/org/wlld/transFormer/model/LineBlockModel.java b/src/main/java/org/wlld/transFormer/model/LineBlockModel.java new file mode 100644 index 0000000..d30f992 --- /dev/null +++ b/src/main/java/org/wlld/transFormer/model/LineBlockModel.java @@ -0,0 +1,24 @@ +package org.wlld.transFormer.model; + +import java.util.List; + +public class LineBlockModel { + private List hiddenNervesModel;//隐层model + private List outNervesModel;//输出层model + + public List getHiddenNervesModel() { + return hiddenNervesModel; + } + + public void setHiddenNervesModel(List hiddenNervesModel) { + this.hiddenNervesModel = hiddenNervesModel; + } + + public List getOutNervesModel() { + return outNervesModel; + } + + public void setOutNervesModel(List outNervesModel) { + this.outNervesModel = outNervesModel; + } +} diff --git a/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttentionModel.java b/src/main/java/org/wlld/transFormer/model/MultiSelfAttentionModel.java similarity index 93% rename from src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttentionModel.java rename to src/main/java/org/wlld/transFormer/model/MultiSelfAttentionModel.java index bf98579..b8cf17a 100644 --- a/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttentionModel.java +++ b/src/main/java/org/wlld/transFormer/model/MultiSelfAttentionModel.java @@ -1,4 +1,4 @@ -package org.wlld.transFormer.seflAttention; +package org.wlld.transFormer.model; import java.util.List; diff --git a/src/main/java/org/wlld/transFormer/seflAttention/QKVModel.java b/src/main/java/org/wlld/transFormer/model/QKVModel.java similarity index 93% rename from src/main/java/org/wlld/transFormer/seflAttention/QKVModel.java rename to src/main/java/org/wlld/transFormer/model/QKVModel.java index 3ae3120..a87d68b 100644 --- a/src/main/java/org/wlld/transFormer/seflAttention/QKVModel.java +++ b/src/main/java/org/wlld/transFormer/model/QKVModel.java @@ -1,4 +1,4 @@ -package org.wlld.transFormer.seflAttention; +package org.wlld.transFormer.model; public class QKVModel { private double[][] Q; diff --git a/src/main/java/org/wlld/transFormer/model/TransFormerModel.java b/src/main/java/org/wlld/transFormer/model/TransFormerModel.java new file mode 100644 index 0000000..0ffd0cd --- /dev/null +++ b/src/main/java/org/wlld/transFormer/model/TransFormerModel.java @@ -0,0 +1,42 @@ +package org.wlld.transFormer.model; + +import java.util.List; + +public class TransFormerModel { + private List encoderBlockModels;//编码器模块 + private List decoderBlockModels;//解码器模块 + private FirstDecoderModel firstDecoderBlockModel;//第一个解码器模块 + private LineBlockModel lineBlockModel;//线性分类层 + + public List getEncoderBlockModels() { + return encoderBlockModels; + } + + public void setEncoderBlockModels(List encoderBlockModels) { + this.encoderBlockModels = encoderBlockModels; + } + + public List getDecoderBlockModels() { + return decoderBlockModels; + } + + public void setDecoderBlockModels(List decoderBlockModels) { + this.decoderBlockModels = decoderBlockModels; + } + + public FirstDecoderModel getFirstDecoderBlockModel() { + return firstDecoderBlockModel; + } + + public void setFirstDecoderBlockModel(FirstDecoderModel firstDecoderBlockModel) { + this.firstDecoderBlockModel = firstDecoderBlockModel; + } + + public LineBlockModel getLineBlockModel() { + return lineBlockModel; + } + + public void setLineBlockModel(LineBlockModel lineBlockModel) { + this.lineBlockModel = lineBlockModel; + } +} diff --git a/src/main/java/org/wlld/transFormer/nerve/HiddenNerve.java b/src/main/java/org/wlld/transFormer/nerve/HiddenNerve.java index 0e01e11..fca54ee 100644 --- a/src/main/java/org/wlld/transFormer/nerve/HiddenNerve.java +++ b/src/main/java/org/wlld/transFormer/nerve/HiddenNerve.java @@ -21,9 +21,9 @@ import java.util.Map; public class HiddenNerve extends Nerve { public HiddenNerve(int id, int depth, double studyPoint, ActiveFunction activeFunction, int sensoryNerveNub, - int outNerveNub, boolean isEncoder, LineBlock lineBlock) throws Exception {//隐层神经元 + int outNerveNub, LineBlock lineBlock) throws Exception {//隐层神经元 super(id, "HiddenNerve", studyPoint, activeFunction, sensoryNerveNub, 0, - outNerveNub, isEncoder, lineBlock); + outNerveNub, lineBlock); this.depth = depth; } @@ -41,7 +41,7 @@ public class HiddenNerve extends Nerve { protected void input(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack, List E, Matrix encoderFeature) throws Exception {//第二层收到参数 boolean allReady = insertMatrixParameter(eventId, parameter); - if (allReady) {//参数齐了,开始计算 sigma - threshold + if (allReady) {//参数齐了,开始计算 Matrix out = opMatrix(reMatrixFeatures.get(eventId), isStudy); reMatrixFeatures.remove(eventId); beforeLayNorm.addNormFromNerve(eventId, isStudy, out, allFeature, outBack, E, encoderFeature); diff --git a/src/main/java/org/wlld/transFormer/nerve/Nerve.java b/src/main/java/org/wlld/transFormer/nerve/Nerve.java index 482ab07..e24b33e 100644 --- a/src/main/java/org/wlld/transFormer/nerve/Nerve.java +++ b/src/main/java/org/wlld/transFormer/nerve/Nerve.java @@ -21,16 +21,12 @@ public abstract class Nerve { private final List father = new ArrayList<>();//树突上一层的连接神经元 protected LayNorm beforeLayNorm;//多头自注意力层 protected LayNorm afterLayNorm;//多头自注意力层 - protected Map dendrites = new HashMap<>();//上一层权重(需要取出) - protected Matrix powerMatrix;//权重矩阵 + protected Matrix powerMatrix;//权重矩阵 作为模型取出 private final int id;//同级神经元编号,注意在同层编号中ID应有唯一性 private final int hiddenNerveNub;//隐层神经元个数 private final int sensoryNerveNub;//输入神经元个数 private final int outNerveNub;//输出神经元个数 - private final boolean encoder; - protected Map> features = new HashMap<>();//上一层神经元输入的数值 protected Map reMatrixFeatures = new HashMap<>(); - protected double threshold;//此神经元的阈值需要取出 protected String name;//该神经元所属类型 protected Matrix featureMatrix; protected double E;//模板期望值 @@ -47,23 +43,6 @@ public abstract class Nerve { return depth; } - public Map getDendrites() { - return dendrites; - } - - - public void setDendrites(Map dendrites) { - this.dendrites = dendrites; - } - - public double getThreshold() { - return threshold; - } - - public void setThreshold(double threshold) { - this.threshold = threshold; - } - public void setBeforeLayNorm(LayNorm beforeLayNorm) { this.beforeLayNorm = beforeLayNorm; } @@ -73,10 +52,9 @@ public abstract class Nerve { } protected Nerve(int id, String name, double studyPoint, ActiveFunction activeFunction, int sensoryNerveNub, - int hiddenNerveNub, int outNerveNub, boolean encoder, LineBlock lineBlock) throws Exception {//该神经元在同层神经元中的编号 + int hiddenNerveNub, int outNerveNub, LineBlock lineBlock) throws Exception {//该神经元在同层神经元中的编号 this.id = id; this.lineBlock = lineBlock; - this.encoder = encoder; this.hiddenNerveNub = hiddenNerveNub;//隐层神经元个数 this.sensoryNerveNub = sensoryNerveNub;//输入神经元个数 this.outNerveNub = outNerveNub;//输出神经元个数 @@ -86,10 +64,17 @@ public abstract class Nerve { initPower();//生成随机权重 } - protected void setStudyPoint(double studyPoint) { - this.studyPoint = studyPoint; + public double[][] getModel() { + return powerMatrix.getMatrix(); } + public void insertModel(double[][] modelPower) throws Exception { + for (int i = 0; i < powerMatrix.getX(); i++) { + for (int j = 0; j < powerMatrix.getY(); j++) { + powerMatrix.setNub(i, j, modelPower[i][j]); + } + } + } protected void sendMessage(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack, List E, Matrix encoderFeature) throws Exception { @@ -208,19 +193,6 @@ public abstract class Nerve { return sigma; } - protected double calculation(long eventId) throws Exception {//计算当前输出结果 - double sigma = 0; - List featuresList = features.get(eventId); - if (dendrites.size() != featuresList.size()) { - throw new Exception("隐层参数数量与权重数量不一致"); - } - for (int i = 0; i < featuresList.size(); i++) { - double value = featuresList.get(i); - double w = dendrites.get(i + 1);//当value不为0的时候把w取出来 - sigma = w * value + sigma; - } - return sigma - threshold; - } private void initPower() throws Exception {//初始化权重及阈值 Random random = new Random(); @@ -235,14 +207,12 @@ public abstract class Nerve { if (myUpNumber > 0) {//输入个数 powerMatrix = new Matrix(myUpNumber + 1, 1); double sh = Math.sqrt(myUpNumber); - for (int i = 1; i < myUpNumber + 1; i++) { + for (int i = 0; i < myUpNumber; i++) { double nub = random.nextDouble() / sh; - dendrites.put(i, nub);//random.nextDouble() - powerMatrix.setNub(i - 1, 0, nub); + powerMatrix.setNub(i, 0, nub); } //生成随机阈值 - threshold = random.nextDouble() / sh; - powerMatrix.setNub(myUpNumber, 0, threshold); + powerMatrix.setNub(myUpNumber, 0, random.nextDouble() / sh); } } diff --git a/src/main/java/org/wlld/transFormer/nerve/OutNerve.java b/src/main/java/org/wlld/transFormer/nerve/OutNerve.java index bc98b97..5a3054d 100644 --- a/src/main/java/org/wlld/transFormer/nerve/OutNerve.java +++ b/src/main/java/org/wlld/transFormer/nerve/OutNerve.java @@ -18,7 +18,7 @@ public class OutNerve extends Nerve { public OutNerve(int id, double studyPoint, int sensoryNerveNub, int hiddenNerveNub, int outNerveNub, SoftMax softMax) throws Exception { super(id, "OutNerve", studyPoint, null, sensoryNerveNub, - hiddenNerveNub, outNerveNub, false, null); + hiddenNerveNub, outNerveNub, null); this.softMax = softMax; } diff --git a/src/main/java/org/wlld/transFormer/nerve/SoftMax.java b/src/main/java/org/wlld/transFormer/nerve/SoftMax.java index 3535e4d..2c7b5b3 100644 --- a/src/main/java/org/wlld/transFormer/nerve/SoftMax.java +++ b/src/main/java/org/wlld/transFormer/nerve/SoftMax.java @@ -1,14 +1,12 @@ package org.wlld.transFormer.nerve; -import org.wlld.config.RZ; import org.wlld.i.OutBack; import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; import java.util.ArrayList; import java.util.List; -import java.util.Map; public class SoftMax extends Nerve { private final List outNerves; @@ -16,8 +14,7 @@ public class SoftMax extends Nerve { public SoftMax(List outNerves, boolean isShowLog , int sensoryNerveNub, int hiddenNerveNub, int outNerveNub) throws Exception { - super(0, "softMax", 0, null, sensoryNerveNub, hiddenNerveNub, outNerveNub, - false, null); + super(0, "softMax", 0, null, sensoryNerveNub, hiddenNerveNub, outNerveNub, null); this.outNerves = outNerves; this.isShowLog = isShowLog; } diff --git a/src/main/java/org/wlld/transFormer/seflAttention/LayNorm.java b/src/main/java/org/wlld/transFormer/seflAttention/LayNorm.java index 88f1df9..dc53bba 100644 --- a/src/main/java/org/wlld/transFormer/seflAttention/LayNorm.java +++ b/src/main/java/org/wlld/transFormer/seflAttention/LayNorm.java @@ -5,6 +5,7 @@ import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; import org.wlld.transFormer.CodecBlock; import org.wlld.transFormer.FirstDecoderBlock; +import org.wlld.transFormer.model.LayNormModel; import org.wlld.transFormer.nerve.HiddenNerve; import java.util.HashMap; @@ -17,16 +18,36 @@ public class LayNorm {//残差与归一化 private final CodecBlock myEncoderBlock; private final int featureDimension;//特征维度 private List hiddenNerves;//第一层隐层 - private final int type;//类别层 + private final int type;//类别层模型需要保存 private final Map reMatrixMap = new HashMap<>(); private final FirstDecoderBlock firstDecoderBlock; - private Matrix bTa; - private Matrix power; + private Matrix bTa;//模型需要保存 + private Matrix power;//模型需要保存 private Matrix myNormData;//第一步归一化后的数据 private final double study;//学习率 private Matrix myFinalError;//从FNN传来的总误差 private int number;//记录fnn传来的误差次数 + public LayNormModel getModel() { + LayNormModel layNormModel = new LayNormModel(); + layNormModel.setbTa(bTa.getMatrix()); + layNormModel.setPower(power.getMatrix()); + return layNormModel; + } + + public void insertModel(LayNormModel layNormModel) throws Exception { + insertPower(layNormModel.getPower(), power); + insertPower(layNormModel.getbTa(), bTa); + } + + private void insertPower(double[][] modelPower, Matrix power) throws Exception { + for (int i = 0; i < power.getX(); i++) { + for (int j = 0; j < power.getY(); j++) { + power.setNub(i, j, modelPower[i][j]); + } + } + } + public LayNorm(int type, int featureDimension, CodecBlock myEncoderBlock, FirstDecoderBlock firstDecoderBlock , double study) throws Exception { this.study = study; diff --git a/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttention.java b/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttention.java index d524d51..e7d71d7 100644 --- a/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttention.java +++ b/src/main/java/org/wlld/transFormer/seflAttention/MultiSelfAttention.java @@ -3,9 +3,9 @@ package org.wlld.transFormer.seflAttention; import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; import org.wlld.i.OutBack; -import org.wlld.tools.Frequency; import org.wlld.transFormer.CodecBlock; -import org.wlld.transFormer.nerve.HiddenNerve; +import org.wlld.transFormer.model.MultiSelfAttentionModel; +import org.wlld.transFormer.model.QKVModel; import java.util.*; diff --git a/src/main/java/org/wlld/transFormer/seflAttention/SelfAttention.java b/src/main/java/org/wlld/transFormer/seflAttention/SelfAttention.java index 5ab6509..9eee9a2 100644 --- a/src/main/java/org/wlld/transFormer/seflAttention/SelfAttention.java +++ b/src/main/java/org/wlld/transFormer/seflAttention/SelfAttention.java @@ -2,6 +2,7 @@ package org.wlld.transFormer.seflAttention; import org.wlld.matrixTools.Matrix; import org.wlld.matrixTools.MatrixOperation; +import org.wlld.transFormer.model.QKVModel; import java.util.HashMap; import java.util.Map;