增加模型参数获取

This commit is contained in:
lidapeng 2020-01-18 11:12:37 +08:00
parent c1e7f86c3f
commit 2501596973
8 changed files with 144 additions and 50 deletions

41
pom.xml
View File

@ -11,22 +11,6 @@
<name>myBrain</name>
<!-- FIXME change it to the project's website -->
<url>http://www.example.com</url>
<!--输入在sonatype创建的账户和联系邮箱 -->
<licenses>
<license>
<name>The Apache Software License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
</license>
</licenses>
<developers>
<developer><!--输入在sonatype创建的账户和联系邮箱 -->
<name>thenk008</name>
<email>794757862@qq.com</email>
<organization>hope-redheart</organization>
<organizationUrl>https://www.cnblogs.com/yjp372928571</organizationUrl>
</developer>
</developers>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
@ -48,32 +32,7 @@
</descriptorRefs>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.6</version>
<executions>
<execution>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</pluginManagement>
</build>
<distributionManagement>
<repository>
<id>releases</id>
<name>Nexus Release Repository</name>
<url>https://oss.sonatype.org/service/local/staging/deploy/maven2</url>
</repository>
<snapshotRepository>
<id>snapshots</id>
<name>Nexus Snapshot Repository</name>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
</distributionManagement>
</project>

View File

@ -29,7 +29,7 @@ public class HelloWorld {
Map<Integer, Double> wrongTagging = new HashMap<>();//分类标注
rightTagging.put(1, 1.0);
wrongTagging.put(1, 0.0);
for (int i = 1; i < 500; i++) {
for (int i = 1; i < 5; i++) {
System.out.println("开始学习1==" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
@ -38,7 +38,7 @@ public class HelloWorld {
operation.learning(right, rightTagging, false);
operation.learning(wrong, wrongTagging, false);
}
for (int i = 1; i < 500; i++) {//神经网络学习
for (int i = 1; i < 5; i++) {//神经网络学习
System.out.println("开始学习2==" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");

View File

@ -1,7 +1,6 @@
package org.wlld.imageRecognition;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.config.StudyPattern;
import org.wlld.function.ReLu;
import org.wlld.function.Sigmod;
@ -99,8 +98,14 @@ public class TempleConfig {
convolutionNerveManager.init(initPower, true, nerveManager);
}
public ModelParameter getModel() {//获取模型参数
return nerveManager.getModelParameter();
public ModelParameter getModel() throws Exception {//获取模型参数
ModelParameter modelParameter = nerveManager.getModelParameter();
if (studyPattern == StudyPattern.Accuracy_Pattern) {
ModelParameter modelParameter1 = convolutionNerveManager.getModelParameter();
modelParameter.setDymNerveStudies(modelParameter1.getDymNerveStudies());
modelParameter.setDymOutNerveStudy(modelParameter1.getDymOutNerveStudy());
}
return modelParameter;
}
public List<SensoryNerve> getSensoryNerves() {//获取感知神经元

View File

@ -4,6 +4,7 @@ import org.wlld.MatrixTools.Matrix;
import org.wlld.i.ActiveFunction;
import org.wlld.i.OutBack;
import org.wlld.nerveEntity.*;
import org.wlld.tools.ArithUtil;
import java.util.ArrayList;
import java.util.HashMap;
@ -48,7 +49,46 @@ public class NerveManager {
}
}
public ModelParameter getModelParameter() {//获取当前模型参数
private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
ModelParameter modelParameter = new ModelParameter();
List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
modelParameter.setDymNerveStudies(dymNerveStudies);
modelParameter.setDymOutNerveStudy(dymOutNerveStudy);
for (int i = 0; i < depthNerves.size(); i++) {
Nerve depthNerve = depthNerves.get(i).get(0);//隐层神经元
DymNerveStudy deepNerveStudy = new DymNerveStudy();//动态神经元输出层
List<Double> list = deepNerveStudy.getList();
deepNerveStudy.setThreshold(depthNerve.getThreshold());//获取偏移值
Matrix matrix = depthNerve.getNerveMatrix();
insertWList(matrix, list);
dymNerveStudies.add(deepNerveStudy);
}
Nerve outNerve = outNevers.get(0);
Matrix matrix = outNerve.getNerveMatrix();
dymOutNerveStudy.setThreshold(outNerve.getThreshold());
List<Double> list = dymOutNerveStudy.getList();
insertWList(matrix, list);
return modelParameter;
}
private void insertWList(Matrix matrix, List<Double> list) throws Exception {//
for (int i = 0; i < matrix.getX(); i++) {
for (int j = 0; j < matrix.getY(); j++) {
list.add(matrix.getNumber(i, j));
}
}
}
public ModelParameter getModelParameter() throws Exception {
if (isDynamic) {
return getDymModelParameter();
} else {
return getStaticModelParameter();
}
}
private ModelParameter getStaticModelParameter() {//获取当前模型参数
ModelParameter modelParameter = new ModelParameter();
List<List<NerveStudy>> studyDepthNerves = new ArrayList<>();//隐层神经元模型
List<NerveStudy> outStudyNevers = new ArrayList<>();//输出神经元
@ -78,8 +118,42 @@ public class NerveManager {
return modelParameter;
}
//注入模型参数
public void insertModelParameter(ModelParameter modelParameter) {
public void insertModelParameter(ModelParameter modelParameter) throws Exception {
insertBpModelParameter(modelParameter);//全连接层注入参数
if (isDynamic) {
insertConvolutionModelParameter(modelParameter);
}
}
//注入卷积层模型参数
private void insertConvolutionModelParameter(ModelParameter modelParameter) throws Exception {
List<DymNerveStudy> dymNerveStudyList = modelParameter.getDymNerveStudies();
DymNerveStudy dymOutNerveStudy = modelParameter.getDymOutNerveStudy();
for (int i = 0; i < depthNerves.size(); i++) {
Nerve depthNerve = depthNerves.get(i).get(0);
DymNerveStudy dymNerveStudy = dymNerveStudyList.get(i);
List<Double> list = dymNerveStudy.getList();
Matrix nerveMatrix = depthNerve.getNerveMatrix();
depthNerve.setThreshold(dymNerveStudy.getThreshold());//注入偏置项
insertMatrix(nerveMatrix, list);
}
Nerve outNerve = outNevers.get(0);
outNerve.setThreshold(dymOutNerveStudy.getThreshold());//输出神经元注入偏置项
Matrix outNervMatrix = outNerve.getNerveMatrix();
List<Double> list = dymOutNerveStudy.getList();
insertMatrix(outNervMatrix, list);
}
private void insertMatrix(Matrix matrix, List<Double> list) throws Exception {
for (int i = 0; i < list.size(); i++) {
int x = i / 3;
int y = i % 3;
matrix.setNub(x, y, list.get(i));
}
}
//注入全连接模型参数
private void insertBpModelParameter(ModelParameter modelParameter) {
List<List<NerveStudy>> depthStudyNerves = modelParameter.getDepthNerves();//隐层神经元
List<NerveStudy> outStudyNevers = modelParameter.getOutNevers();//输出神经元
//隐层神经元参数注入

View File

@ -0,0 +1,30 @@
package org.wlld.nerveEntity;
import java.util.ArrayList;
import java.util.List;
/**
* @author lidapeng
* @description 动态神经元模型参数
* @date 8:14 上午 2020/1/18
*/
public class DymNerveStudy {
private List<Double> list = new ArrayList<>();
private double threshold;//此神经元的阈值需要取出
public List<Double> getList() {
return list;
}
public void setList(List<Double> list) {
this.list = list;
}
public double getThreshold() {
return threshold;
}
public void setThreshold(double threshold) {
this.threshold = threshold;
}
}

View File

@ -14,6 +14,24 @@ public class ModelParameter {
//神经远模型参数
private List<List<NerveStudy>> depthNerves = new ArrayList<>();//隐层神经元
private List<NerveStudy> outNevers = new ArrayList<>();//输出神经元
private List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
private DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
public List<DymNerveStudy> getDymNerveStudies() {
return dymNerveStudies;
}
public void setDymNerveStudies(List<DymNerveStudy> dymNerveStudies) {
this.dymNerveStudies = dymNerveStudies;
}
public DymNerveStudy getDymOutNerveStudy() {
return dymOutNerveStudy;
}
public void setDymOutNerveStudy(DymNerveStudy dymOutNerveStudy) {
this.dymOutNerveStudy = dymOutNerveStudy;
}
public List<List<NerveStudy>> getDepthNerves() {
return depthNerves;

View File

@ -39,6 +39,14 @@ public abstract class Nerve {
return dendrites;
}
public Matrix getNerveMatrix() {
return nerveMatrix;
}
public void setNerveMatrix(Matrix nerveMatrix) {
this.nerveMatrix = nerveMatrix;
}
public void setDendrites(Map<Integer, Double> dendrites) {
this.dendrites = dendrites;
}

View File

@ -69,7 +69,7 @@ public class OutNerve extends Nerve {
matrixF = new Matrix(myMatrix.getX(), myMatrix.getY());
}
if (isKernelStudy) {//回传
// System.out.println(myMatrix.getString());
// System.out.println(myMatrix.getString());
for (Map.Entry<Integer, Double> entry : E.entrySet()) {
double g;
if (entry.getValue() > 0.5) {//正模板