mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-11-29 18:27:37 +08:00
基础神经网络的softMax计算性能优化
This commit is contained in:
parent
1b225aa73b
commit
72a6fb0eb7
@ -45,6 +45,7 @@ public class NerveManager {
|
||||
public void setMatrixMap(Map<Integer, Matrix> matrixMap) {
|
||||
this.matrixMap = matrixMap;
|
||||
}
|
||||
|
||||
private Map<String, Double> conversion(Map<Integer, Double> map) {
|
||||
Map<String, Double> cMap = new HashMap<>();
|
||||
for (Map.Entry<Integer, Double> entry : map.entrySet()) {
|
||||
@ -60,6 +61,7 @@ public class NerveManager {
|
||||
}
|
||||
return cMap;
|
||||
}
|
||||
|
||||
private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
|
||||
ModelParameter modelParameter = new ModelParameter();
|
||||
List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
|
||||
@ -184,7 +186,7 @@ public class NerveManager {
|
||||
NerveStudy nerveStudy = outStudyNerves.get(i);
|
||||
outNerve.setThreshold(nerveStudy.getThreshold());
|
||||
Map<Integer, Double> dendrites = outNerve.getDendrites();
|
||||
Map<Integer, Double> studyDendrites =unConversion(nerveStudy.getDendrites());
|
||||
Map<Integer, Double> studyDendrites = unConversion(nerveStudy.getDendrites());
|
||||
for (Map.Entry<Integer, Double> outEntry : dendrites.entrySet()) {
|
||||
int key = outEntry.getKey();
|
||||
dendrites.put(key, studyDendrites.get(key));
|
||||
@ -246,6 +248,7 @@ public class NerveManager {
|
||||
List<Nerve> nerveList = depthNerves.get(0);//第一层隐层神经元
|
||||
//最后一层隐层神经元啊
|
||||
List<Nerve> lastNerveList = depthNerves.get(depthNerves.size() - 1);
|
||||
List<OutNerve> myOutNerveList = new ArrayList<>();
|
||||
//初始化输出神经元
|
||||
for (int i = 1; i < outNerveNub + 1; i++) {
|
||||
OutNerve outNerve = new OutNerve(i, hiddenNerveNub, 0, studyPoint, initPower,
|
||||
@ -253,16 +256,15 @@ public class NerveManager {
|
||||
if (isMatrix) {//是卷积层神经网络
|
||||
outNerve.setMatrixMap(matrixMap);
|
||||
}
|
||||
if (isSoftMax) {
|
||||
SoftMax softMax = new SoftMax(i, outNerveNub, false, outNerve, isShowLog);
|
||||
softMaxList.add(softMax);
|
||||
}
|
||||
//输出层神经元连接最后一层隐层神经元
|
||||
outNerve.connectFather(lastNerveList);
|
||||
outNerves.add(outNerve);
|
||||
myOutNerveList.add(outNerve);
|
||||
}
|
||||
//生成softMax层
|
||||
if (isSoftMax) {//增加softMax层
|
||||
SoftMax softMax = new SoftMax(outNerveNub, false, myOutNerveList, isShowLog);
|
||||
softMaxList.add(softMax);
|
||||
for (Nerve nerve : outNerves) {
|
||||
nerve.connect(softMaxList);
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ public class OutNerve extends Nerve {
|
||||
this.isSoftMax = isSoftMax;
|
||||
}
|
||||
|
||||
void getGBySoftMax(double g, long eventId, int id) throws Exception {//接收softMax层回传梯度
|
||||
void getGBySoftMax(double g, long eventId) throws Exception {//接收softMax层回传梯度
|
||||
gradient = g;
|
||||
updatePower(eventId);
|
||||
}
|
||||
|
@ -3,17 +3,18 @@ package org.wlld.nerveEntity;
|
||||
import org.wlld.config.RZ;
|
||||
import org.wlld.i.OutBack;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SoftMax extends Nerve {
|
||||
private OutNerve outNerve;
|
||||
private boolean isShowLog;
|
||||
private final List<OutNerve> outNerves;
|
||||
private final boolean isShowLog;
|
||||
|
||||
public SoftMax(int id, int upNub, boolean isDynamic, OutNerve outNerve, boolean isShowLog) throws Exception {
|
||||
super(id, upNub, "softMax", 0, 0, false, null, isDynamic
|
||||
public SoftMax(int upNub, boolean isDynamic, List<OutNerve> outNerves, boolean isShowLog) throws Exception {
|
||||
super(0, upNub, "softMax", 0, 0, false, null, isDynamic
|
||||
, RZ.NOT_RZ, 0, 0, 0);
|
||||
this.outNerve = outNerve;
|
||||
this.outNerves = outNerves;
|
||||
this.isShowLog = isShowLog;
|
||||
}
|
||||
|
||||
@ -21,24 +22,28 @@ public class SoftMax extends Nerve {
|
||||
protected void input(long eventId, double parameter, boolean isStudy, Map<Integer, Double> E, OutBack outBack) throws Exception {
|
||||
boolean allReady = insertParameter(eventId, parameter);
|
||||
if (allReady) {
|
||||
double out = softMax(eventId);//输出值
|
||||
Mes mes = softMax(eventId, isStudy);//输出值
|
||||
int key = 0;
|
||||
if (isStudy) {//学习
|
||||
outNub = out;
|
||||
if (E.containsKey(getId())) {
|
||||
this.E = E.get(getId());
|
||||
} else {
|
||||
this.E = 0;
|
||||
for (Map.Entry<Integer, Double> entry : E.entrySet()) {
|
||||
if (entry.getValue() > 0.9) {
|
||||
key = entry.getKey();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (isShowLog) {
|
||||
System.out.println("softMax==" + this.E + ",out==" + out + ",nerveId==" + getId());
|
||||
System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
|
||||
}
|
||||
gradient = -outGradient();//当前梯度变化 把梯度返回
|
||||
List<Double> errors = error(mes, key);
|
||||
features.remove(eventId); //清空当前上层输入参数参数
|
||||
outNerve.getGBySoftMax(gradient, eventId, getId());
|
||||
int size = outNerves.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
outNerves.get(i).getGBySoftMax(errors.get(i), eventId);
|
||||
}
|
||||
} else {//输出
|
||||
destoryParameter(eventId);
|
||||
if (outBack != null) {
|
||||
outBack.getBack(out, getId(), eventId);
|
||||
outBack.getBack(mes.poi, mes.typeID, eventId);
|
||||
} else {
|
||||
throw new Exception("not find outBack");
|
||||
}
|
||||
@ -46,26 +51,53 @@ public class SoftMax extends Nerve {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private double outGradient() {//生成输出层神经元梯度变化
|
||||
double g = outNub;
|
||||
if (E == 1) {
|
||||
//g = ArithUtil.sub(g, 1);
|
||||
g = g - 1;
|
||||
private List<Double> error(Mes mes, int key) {
|
||||
int t = key - 1;
|
||||
List<Double> softMax = mes.softMax;
|
||||
List<Double> error = new ArrayList<>();
|
||||
for (int i = 0; i < softMax.size(); i++) {
|
||||
double self = softMax.get(i);
|
||||
double myError;
|
||||
if (i != t) {
|
||||
myError = -self;
|
||||
} else {
|
||||
myError = 1 - self;
|
||||
}
|
||||
error.add(myError);
|
||||
}
|
||||
return g;
|
||||
return error;
|
||||
}
|
||||
|
||||
private double softMax(long eventId) {//计算当前输出结果
|
||||
private Mes softMax(long eventId, boolean isStudy) {//计算当前输出结果
|
||||
double sigma = 0;
|
||||
int id = 0;
|
||||
double poi = 0;
|
||||
Mes mes = new Mes();
|
||||
List<Double> featuresList = features.get(eventId);
|
||||
double self = featuresList.get(getId() - 1);
|
||||
double eSelf = Math.exp(self);
|
||||
for (int i = 0; i < featuresList.size(); i++) {
|
||||
double value = featuresList.get(i);
|
||||
// sigma = ArithUtil.add(Math.exp(value), sigma);
|
||||
for (double value : featuresList) {
|
||||
sigma = Math.exp(value) + sigma;
|
||||
}
|
||||
return eSelf / sigma;//ArithUtil.div(eSelf, sigma);
|
||||
List<Double> softMax = new ArrayList<>();
|
||||
for (int i = 0; i < featuresList.size(); i++) {
|
||||
double eSelf = Math.exp(featuresList.get(i));
|
||||
double value = eSelf / sigma;
|
||||
if (isStudy) {
|
||||
softMax.add(value);
|
||||
}
|
||||
if (value > poi) {
|
||||
poi = value;
|
||||
id = i + 1;
|
||||
}
|
||||
}
|
||||
mes.softMax = softMax;
|
||||
mes.typeID = id;
|
||||
mes.poi = poi;
|
||||
return mes;
|
||||
}
|
||||
|
||||
static class Mes {
|
||||
int typeID;
|
||||
double poi;
|
||||
List<Double> softMax;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user