From 72a6fb0eb74aa24701dbd4829e7babcf0b10739d Mon Sep 17 00:00:00 2001
From: 李大鹏 <794757862@qq.com>
Date: Wed, 10 Apr 2024 10:57:14 +0800
Subject: [PATCH] Performance optimization of the softMax computation in the basic neural network
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../org/wlld/nerveCenter/NerveManager.java  | 12 +--
 .../java/org/wlld/nerveEntity/OutNerve.java |  2 +-
 .../java/org/wlld/nerveEntity/SoftMax.java  | 90 +++++++++++++------
 3 files changed, 69 insertions(+), 35 deletions(-)

diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java
index 4bdea92..b7ae256 100644
--- a/src/main/java/org/wlld/nerveCenter/NerveManager.java
+++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java
@@ -45,6 +45,7 @@ public class NerveManager {
     public void setMatrixMap(Map matrixMap) {
         this.matrixMap = matrixMap;
     }
+
     private Map conversion(Map map) {
         Map cMap = new HashMap<>();
         for (Map.Entry entry : map.entrySet()) {
@@ -60,6 +61,7 @@ public class NerveManager {
         }
         return cMap;
     }
+
     private ModelParameter getDymModelParameter() throws Exception {//get the dynamic neuron parameters
         ModelParameter modelParameter = new ModelParameter();
         List dymNerveStudies = new ArrayList<>();//dynamic neuron hidden layers
@@ -184,7 +186,7 @@ public class NerveManager {
             NerveStudy nerveStudy = outStudyNerves.get(i);
             outNerve.setThreshold(nerveStudy.getThreshold());
             Map dendrites = outNerve.getDendrites();
-            Map studyDendrites =unConversion(nerveStudy.getDendrites());
+            Map studyDendrites = unConversion(nerveStudy.getDendrites());
             for (Map.Entry outEntry : dendrites.entrySet()) {
                 int key = outEntry.getKey();
                 dendrites.put(key, studyDendrites.get(key));
@@ -246,6 +248,7 @@ public class NerveManager {
         List nerveList = depthNerves.get(0);//neurons of the first hidden layer
         //neurons of the last hidden layer
         List lastNerveList = depthNerves.get(depthNerves.size() - 1);
+        List myOutNerveList = new ArrayList<>();
         //initialize the output neurons
         for (int i = 1; i < outNerveNub + 1; i++) {
             OutNerve outNerve = new OutNerve(i, hiddenNerveNub, 0, studyPoint, initPower,
@@ -253,16 +256,15 @@ public class NerveManager {
             if (isMatrix) {//this is a convolutional neural network
                 outNerve.setMatrixMap(matrixMap);
             }
-            if (isSoftMax) {
-                SoftMax softMax = new SoftMax(i, outNerveNub, false, outNerve, isShowLog);
-                softMaxList.add(softMax);
-            }
             //the output neurons connect to the neurons of the last hidden layer
             outNerve.connectFather(lastNerveList);
             outNerves.add(outNerve);
+            myOutNerveList.add(outNerve);
         }
         //generate the softMax layer
        if (isSoftMax) {//add the softMax layer
+            SoftMax softMax = new SoftMax(outNerveNub, false, myOutNerveList, isShowLog);
+            softMaxList.add(softMax);
             for (Nerve nerve : outNerves) {
                 nerve.connect(softMaxList);
             }
diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java
index 1ee7a74..4f4b947 100644
--- a/src/main/java/org/wlld/nerveEntity/OutNerve.java
+++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java
@@ -24,7 +24,7 @@ public class OutNerve extends Nerve {
         this.isSoftMax = isSoftMax;
     }
 
-    void getGBySoftMax(double g, long eventId, int id) throws Exception {//receive the gradient passed back from the softMax layer
+    void getGBySoftMax(double g, long eventId) throws Exception {//receive the gradient passed back from the softMax layer
         gradient = g;
         updatePower(eventId);
     }
diff --git a/src/main/java/org/wlld/nerveEntity/SoftMax.java b/src/main/java/org/wlld/nerveEntity/SoftMax.java
index fbf05c8..37f9108 100644
--- a/src/main/java/org/wlld/nerveEntity/SoftMax.java
+++ b/src/main/java/org/wlld/nerveEntity/SoftMax.java
@@ -3,17 +3,18 @@ package org.wlld.nerveEntity;
 import org.wlld.config.RZ;
 import org.wlld.i.OutBack;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
 public class SoftMax extends Nerve {
-    private OutNerve outNerve;
-    private boolean isShowLog;
+    private final List outNerves;
+    private final boolean isShowLog;
 
-    public SoftMax(int id, int upNub, boolean isDynamic, OutNerve outNerve, boolean isShowLog) throws Exception {
-        super(id, upNub, "softMax", 0, 0, false, null, isDynamic
+    public SoftMax(int upNub, boolean isDynamic, List outNerves, boolean isShowLog) throws Exception {
+        super(0, upNub, "softMax", 0, 0, false, null, isDynamic
                 , RZ.NOT_RZ, 0, 0, 0);
-        this.outNerve = outNerve;
+        this.outNerves = outNerves;
         this.isShowLog = isShowLog;
     }
@@ -21,24 +22,28 @@ public class SoftMax extends Nerve {
     protected void input(long eventId, double parameter, boolean isStudy, Map E, OutBack outBack) throws Exception {
         boolean allReady = insertParameter(eventId, parameter);
         if (allReady) {
-            double out = softMax(eventId);//output value
+            Mes mes = softMax(eventId, isStudy);//output value
+            int key = 0;
             if (isStudy) {//learning
-                outNub = out;
-                if (E.containsKey(getId())) {
-                    this.E = E.get(getId());
-                } else {
-                    this.E = 0;
+                for (Map.Entry entry : E.entrySet()) {
+                    if (entry.getValue() > 0.9) {
+                        key = entry.getKey();
+                        break;
+                    }
                 }
                 if (isShowLog) {
-                    System.out.println("softMax==" + this.E + ",out==" + out + ",nerveId==" + getId());
+                    System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
                 }
-                gradient = -outGradient();//current gradient change; pass the gradient back
+                List errors = error(mes, key);
                 features.remove(eventId); //clear the current upstream input parameters
-                outNerve.getGBySoftMax(gradient, eventId, getId());
+                int size = outNerves.size();
+                for (int i = 0; i < size; i++) {
+                    outNerves.get(i).getGBySoftMax(errors.get(i), eventId);
+                }
             } else {//output
                 destoryParameter(eventId);
                 if (outBack != null) {
-                    outBack.getBack(out, getId(), eventId);
+                    outBack.getBack(mes.poi, mes.typeID, eventId);
                 } else {
                     throw new Exception("not find outBack");
                 }
@@ -46,26 +51,53 @@
         }
     }
-
-    private double outGradient() {//generate the gradient change for the output-layer neuron
-        double g = outNub;
-        if (E == 1) {
-            //g = ArithUtil.sub(g, 1);
-            g = g - 1;
+    private List error(Mes mes, int key) {
+        int t = key - 1;
+        List softMax = mes.softMax;
+        List error = new ArrayList<>();
+        for (int i = 0; i < softMax.size(); i++) {
+            double self = softMax.get(i);
+            double myError;
+            if (i != t) {
+                myError = -self;
+            } else {
+                myError = 1 - self;
+            }
+            error.add(myError);
         }
-        return g;
+        return error;
     }
 
-    private double softMax(long eventId) {//compute the current output result
+    private Mes softMax(long eventId, boolean isStudy) {//compute the current output result
        double sigma = 0;
+        int id = 0;
+        double poi = 0;
+        Mes mes = new Mes();
         List featuresList = features.get(eventId);
-        double self = featuresList.get(getId() - 1);
-        double eSelf = Math.exp(self);
-        for (int i = 0; i < featuresList.size(); i++) {
-            double value = featuresList.get(i);
-            // sigma = ArithUtil.add(Math.exp(value), sigma);
+        for (double value : featuresList) {
             sigma = Math.exp(value) + sigma;
         }
-        return eSelf / sigma;//ArithUtil.div(eSelf, sigma);
+        List softMax = new ArrayList<>();
+        for (int i = 0; i < featuresList.size(); i++) {
+            double eSelf = Math.exp(featuresList.get(i));
+            double value = eSelf / sigma;
+            if (isStudy) {
+                softMax.add(value);
+            }
+            if (value > poi) {
+                poi = value;
+                id = i + 1;
+            }
+        }
+        mes.softMax = softMax;
+        mes.typeID = id;
+        mes.poi = poi;
+        return mes;
+    }
+
+    static class Mes {
+        int typeID;
+        double poi;
+        List softMax;
     }
 }
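
Note on the change: the speed-up comes from computing the softMax denominator once per event for the whole output layer, where the old code created one SoftMax instance per output neuron and each instance re-ran Math.exp over the full feature list, and from sending each output neuron its own cross-entropy error instead of a single self-computed gradient. The following is a minimal standalone Java sketch of that shared-denominator computation and of the per-neuron error it distributes through getGBySoftMax; the class name SoftMaxSketch and the 0-based targetIndex are illustrative only and do not exist in the repository, and the sketch works on plain lists rather than the Nerve event/feature machinery used by the patch.

import java.util.ArrayList;
import java.util.List;

// Illustrative sketch: the softMax denominator (sigma) is computed once and
// shared by every output value, the same shared-denominator idea as the
// patched SoftMax.softMax; the training error follows the patched
// SoftMax.error: 1 - p for the expected class, -p for every other class.
public class SoftMaxSketch {

    static List<Double> softMax(List<Double> outputs) {
        double sigma = 0;
        for (double value : outputs) {
            sigma += Math.exp(value); // single pass over the raw outputs
        }
        List<Double> probabilities = new ArrayList<>();
        for (double value : outputs) {
            probabilities.add(Math.exp(value) / sigma);
        }
        return probabilities;
    }

    // targetIndex is 0-based here; the patch uses a 1-based key and subtracts 1.
    static List<Double> error(List<Double> probabilities, int targetIndex) {
        List<Double> errors = new ArrayList<>();
        for (int i = 0; i < probabilities.size(); i++) {
            double p = probabilities.get(i);
            errors.add(i == targetIndex ? 1 - p : -p);
        }
        return errors;
    }

    public static void main(String[] args) {
        List<Double> outputs = List.of(2.0, 1.0, 0.1); // raw output-layer values
        List<Double> probabilities = softMax(outputs);
        List<Double> errors = error(probabilities, 0); // first class is the target
        System.out.println("probabilities = " + probabilities);
        System.out.println("errors = " + errors);
    }
}

With N output neurons this reduces the exponential evaluations per training event from roughly N * N (every per-neuron SoftMax instance summing over all N outputs) to about 2 * N, while the error values handed to each OutNerve match what the new SoftMax.error produces.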