mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-11-30 10:47:49 +08:00
增加模型参数获取
This commit is contained in:
parent
c1e7f86c3f
commit
2501596973
41
pom.xml
41
pom.xml
@ -11,22 +11,6 @@
|
||||
<name>myBrain</name>
|
||||
<!-- FIXME change it to the project's website -->
|
||||
<url>http://www.example.com</url>
|
||||
<!--输入在sonatype创建的账户和联系邮箱 -->
|
||||
<licenses>
|
||||
<license>
|
||||
<name>The Apache Software License, Version 2.0</name>
|
||||
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
</license>
|
||||
</licenses>
|
||||
<developers>
|
||||
<developer><!--输入在sonatype创建的账户和联系邮箱 -->
|
||||
<name>thenk008</name>
|
||||
<email>794757862@qq.com</email>
|
||||
<organization>hope-redheart</organization>
|
||||
<organizationUrl>https://www.cnblogs.com/yjp372928571</organizationUrl>
|
||||
</developer>
|
||||
</developers>
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
@ -48,32 +32,7 @@
|
||||
</descriptorRefs>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-gpg-plugin</artifactId>
|
||||
<version>1.6</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>verify</phase>
|
||||
<goals>
|
||||
<goal>sign</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</pluginManagement>
|
||||
</build>
|
||||
<distributionManagement>
|
||||
<repository>
|
||||
<id>releases</id>
|
||||
<name>Nexus Release Repository</name>
|
||||
<url>https://oss.sonatype.org/service/local/staging/deploy/maven2</url>
|
||||
</repository>
|
||||
<snapshotRepository>
|
||||
<id>snapshots</id>
|
||||
<name>Nexus Snapshot Repository</name>
|
||||
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
|
||||
</snapshotRepository>
|
||||
</distributionManagement>
|
||||
</project>
|
||||
|
@ -29,7 +29,7 @@ public class HelloWorld {
|
||||
Map<Integer, Double> wrongTagging = new HashMap<>();//分类标注
|
||||
rightTagging.put(1, 1.0);
|
||||
wrongTagging.put(1, 0.0);
|
||||
for (int i = 1; i < 500; i++) {
|
||||
for (int i = 1; i < 5; i++) {
|
||||
System.out.println("开始学习1==" + i);
|
||||
//读取本地URL地址图片,并转化成矩阵
|
||||
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
|
||||
@ -38,7 +38,7 @@ public class HelloWorld {
|
||||
operation.learning(right, rightTagging, false);
|
||||
operation.learning(wrong, wrongTagging, false);
|
||||
}
|
||||
for (int i = 1; i < 500; i++) {//神经网络学习
|
||||
for (int i = 1; i < 5; i++) {//神经网络学习
|
||||
System.out.println("开始学习2==" + i);
|
||||
//读取本地URL地址图片,并转化成矩阵
|
||||
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
|
||||
|
@ -1,7 +1,6 @@
|
||||
package org.wlld.imageRecognition;
|
||||
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.MatrixTools.MatrixOperation;
|
||||
import org.wlld.config.StudyPattern;
|
||||
import org.wlld.function.ReLu;
|
||||
import org.wlld.function.Sigmod;
|
||||
@ -99,8 +98,14 @@ public class TempleConfig {
|
||||
convolutionNerveManager.init(initPower, true, nerveManager);
|
||||
}
|
||||
|
||||
public ModelParameter getModel() {//获取模型参数
|
||||
return nerveManager.getModelParameter();
|
||||
public ModelParameter getModel() throws Exception {//获取模型参数
|
||||
ModelParameter modelParameter = nerveManager.getModelParameter();
|
||||
if (studyPattern == StudyPattern.Accuracy_Pattern) {
|
||||
ModelParameter modelParameter1 = convolutionNerveManager.getModelParameter();
|
||||
modelParameter.setDymNerveStudies(modelParameter1.getDymNerveStudies());
|
||||
modelParameter.setDymOutNerveStudy(modelParameter1.getDymOutNerveStudy());
|
||||
}
|
||||
return modelParameter;
|
||||
}
|
||||
|
||||
public List<SensoryNerve> getSensoryNerves() {//获取感知神经元
|
||||
|
@ -4,6 +4,7 @@ import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.i.ActiveFunction;
|
||||
import org.wlld.i.OutBack;
|
||||
import org.wlld.nerveEntity.*;
|
||||
import org.wlld.tools.ArithUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@ -48,7 +49,46 @@ public class NerveManager {
|
||||
}
|
||||
}
|
||||
|
||||
public ModelParameter getModelParameter() {//获取当前模型参数
|
||||
private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
|
||||
ModelParameter modelParameter = new ModelParameter();
|
||||
List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
|
||||
DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
|
||||
modelParameter.setDymNerveStudies(dymNerveStudies);
|
||||
modelParameter.setDymOutNerveStudy(dymOutNerveStudy);
|
||||
for (int i = 0; i < depthNerves.size(); i++) {
|
||||
Nerve depthNerve = depthNerves.get(i).get(0);//隐层神经元
|
||||
DymNerveStudy deepNerveStudy = new DymNerveStudy();//动态神经元输出层
|
||||
List<Double> list = deepNerveStudy.getList();
|
||||
deepNerveStudy.setThreshold(depthNerve.getThreshold());//获取偏移值
|
||||
Matrix matrix = depthNerve.getNerveMatrix();
|
||||
insertWList(matrix, list);
|
||||
dymNerveStudies.add(deepNerveStudy);
|
||||
}
|
||||
Nerve outNerve = outNevers.get(0);
|
||||
Matrix matrix = outNerve.getNerveMatrix();
|
||||
dymOutNerveStudy.setThreshold(outNerve.getThreshold());
|
||||
List<Double> list = dymOutNerveStudy.getList();
|
||||
insertWList(matrix, list);
|
||||
return modelParameter;
|
||||
}
|
||||
|
||||
private void insertWList(Matrix matrix, List<Double> list) throws Exception {//
|
||||
for (int i = 0; i < matrix.getX(); i++) {
|
||||
for (int j = 0; j < matrix.getY(); j++) {
|
||||
list.add(matrix.getNumber(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public ModelParameter getModelParameter() throws Exception {
|
||||
if (isDynamic) {
|
||||
return getDymModelParameter();
|
||||
} else {
|
||||
return getStaticModelParameter();
|
||||
}
|
||||
}
|
||||
|
||||
private ModelParameter getStaticModelParameter() {//获取当前模型参数
|
||||
ModelParameter modelParameter = new ModelParameter();
|
||||
List<List<NerveStudy>> studyDepthNerves = new ArrayList<>();//隐层神经元模型
|
||||
List<NerveStudy> outStudyNevers = new ArrayList<>();//输出神经元
|
||||
@ -78,8 +118,42 @@ public class NerveManager {
|
||||
return modelParameter;
|
||||
}
|
||||
|
||||
//注入模型参数
|
||||
public void insertModelParameter(ModelParameter modelParameter) {
|
||||
public void insertModelParameter(ModelParameter modelParameter) throws Exception {
|
||||
insertBpModelParameter(modelParameter);//全连接层注入参数
|
||||
if (isDynamic) {
|
||||
insertConvolutionModelParameter(modelParameter);
|
||||
}
|
||||
}
|
||||
|
||||
//注入卷积层模型参数
|
||||
private void insertConvolutionModelParameter(ModelParameter modelParameter) throws Exception {
|
||||
List<DymNerveStudy> dymNerveStudyList = modelParameter.getDymNerveStudies();
|
||||
DymNerveStudy dymOutNerveStudy = modelParameter.getDymOutNerveStudy();
|
||||
for (int i = 0; i < depthNerves.size(); i++) {
|
||||
Nerve depthNerve = depthNerves.get(i).get(0);
|
||||
DymNerveStudy dymNerveStudy = dymNerveStudyList.get(i);
|
||||
List<Double> list = dymNerveStudy.getList();
|
||||
Matrix nerveMatrix = depthNerve.getNerveMatrix();
|
||||
depthNerve.setThreshold(dymNerveStudy.getThreshold());//注入偏置项
|
||||
insertMatrix(nerveMatrix, list);
|
||||
}
|
||||
Nerve outNerve = outNevers.get(0);
|
||||
outNerve.setThreshold(dymOutNerveStudy.getThreshold());//输出神经元注入偏置项
|
||||
Matrix outNervMatrix = outNerve.getNerveMatrix();
|
||||
List<Double> list = dymOutNerveStudy.getList();
|
||||
insertMatrix(outNervMatrix, list);
|
||||
}
|
||||
|
||||
private void insertMatrix(Matrix matrix, List<Double> list) throws Exception {
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
int x = i / 3;
|
||||
int y = i % 3;
|
||||
matrix.setNub(x, y, list.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
//注入全连接模型参数
|
||||
private void insertBpModelParameter(ModelParameter modelParameter) {
|
||||
List<List<NerveStudy>> depthStudyNerves = modelParameter.getDepthNerves();//隐层神经元
|
||||
List<NerveStudy> outStudyNevers = modelParameter.getOutNevers();//输出神经元
|
||||
//隐层神经元参数注入
|
||||
|
30
src/main/java/org/wlld/nerveEntity/DymNerveStudy.java
Normal file
30
src/main/java/org/wlld/nerveEntity/DymNerveStudy.java
Normal file
@ -0,0 +1,30 @@
|
||||
package org.wlld.nerveEntity;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @description 动态神经元模型参数
|
||||
* @date 8:14 上午 2020/1/18
|
||||
*/
|
||||
public class DymNerveStudy {
|
||||
private List<Double> list = new ArrayList<>();
|
||||
private double threshold;//此神经元的阈值需要取出
|
||||
|
||||
public List<Double> getList() {
|
||||
return list;
|
||||
}
|
||||
|
||||
public void setList(List<Double> list) {
|
||||
this.list = list;
|
||||
}
|
||||
|
||||
public double getThreshold() {
|
||||
return threshold;
|
||||
}
|
||||
|
||||
public void setThreshold(double threshold) {
|
||||
this.threshold = threshold;
|
||||
}
|
||||
}
|
@ -14,6 +14,24 @@ public class ModelParameter {
|
||||
//神经远模型参数
|
||||
private List<List<NerveStudy>> depthNerves = new ArrayList<>();//隐层神经元
|
||||
private List<NerveStudy> outNevers = new ArrayList<>();//输出神经元
|
||||
private List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
|
||||
private DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
|
||||
|
||||
public List<DymNerveStudy> getDymNerveStudies() {
|
||||
return dymNerveStudies;
|
||||
}
|
||||
|
||||
public void setDymNerveStudies(List<DymNerveStudy> dymNerveStudies) {
|
||||
this.dymNerveStudies = dymNerveStudies;
|
||||
}
|
||||
|
||||
public DymNerveStudy getDymOutNerveStudy() {
|
||||
return dymOutNerveStudy;
|
||||
}
|
||||
|
||||
public void setDymOutNerveStudy(DymNerveStudy dymOutNerveStudy) {
|
||||
this.dymOutNerveStudy = dymOutNerveStudy;
|
||||
}
|
||||
|
||||
public List<List<NerveStudy>> getDepthNerves() {
|
||||
return depthNerves;
|
||||
|
@ -39,6 +39,14 @@ public abstract class Nerve {
|
||||
return dendrites;
|
||||
}
|
||||
|
||||
public Matrix getNerveMatrix() {
|
||||
return nerveMatrix;
|
||||
}
|
||||
|
||||
public void setNerveMatrix(Matrix nerveMatrix) {
|
||||
this.nerveMatrix = nerveMatrix;
|
||||
}
|
||||
|
||||
public void setDendrites(Map<Integer, Double> dendrites) {
|
||||
this.dendrites = dendrites;
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ public class OutNerve extends Nerve {
|
||||
matrixF = new Matrix(myMatrix.getX(), myMatrix.getY());
|
||||
}
|
||||
if (isKernelStudy) {//回传
|
||||
// System.out.println(myMatrix.getString());
|
||||
// System.out.println(myMatrix.getString());
|
||||
for (Map.Entry<Integer, Double> entry : E.entrySet()) {
|
||||
double g;
|
||||
if (entry.getValue() > 0.5) {//正模板
|
||||
|
Loading…
Reference in New Issue
Block a user