修改模型获取类型数据类型转化

This commit is contained in:
794757862@qq.com 2022-07-15 10:39:06 +08:00
parent d8dadea6bc
commit bc62b6e202
4 changed files with 26 additions and 12 deletions

View File

@ -45,7 +45,21 @@ public class NerveManager {
public void setMatrixMap(Map<Integer, Matrix> matrixMap) {
this.matrixMap = matrixMap;
}
private Map<String, Double> conversion(Map<Integer, Double> map) {
Map<String, Double> cMap = new HashMap<>();
for (Map.Entry<Integer, Double> entry : map.entrySet()) {
cMap.put(String.valueOf(entry.getKey()), entry.getValue());
}
return cMap;
}
private Map<Integer, Double> unConversion(Map<String, Double> map) {
Map<Integer, Double> cMap = new HashMap<>();
for (Map.Entry<String, Double> entry : map.entrySet()) {
cMap.put(Integer.parseInt(entry.getKey()), entry.getValue());
}
return cMap;
}
private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
ModelParameter modelParameter = new ModelParameter();
List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
@ -96,7 +110,7 @@ public class NerveManager {
NerveStudy nerveStudy = new NerveStudy();
Nerve hiddenNerve = depthNerve.get(j);
nerveStudy.setThreshold(hiddenNerve.getThreshold());
nerveStudy.setDendrites(hiddenNerve.getDendrites());
nerveStudy.setDendrites(conversion(hiddenNerve.getDendrites()));
deepNerve.add(nerveStudy);
}
studyDepthNerves.add(deepNerve);
@ -105,7 +119,7 @@ public class NerveManager {
NerveStudy nerveStudy = new NerveStudy();
Nerve outNerve = outNerves.get(i);
nerveStudy.setThreshold(outNerve.getThreshold());
nerveStudy.setDendrites(outNerve.getDendrites());
nerveStudy.setDendrites(conversion(outNerve.getDendrites()));
outStudyNerves.add(nerveStudy);
}
modelParameter.setDepthNerves(studyDepthNerves);
@ -156,7 +170,7 @@ public class NerveManager {
Nerve nerve = depthNerve.get(j);
NerveStudy nerveStudy = depth.get(j);
//学习结果
Map<Integer, Double> studyDendrites = nerveStudy.getDendrites();
Map<Integer, Double> studyDendrites = unConversion(nerveStudy.getDendrites());
//神经元参数注入
Map<Integer, Double> dendrites = nerve.getDendrites();
nerve.setThreshold(nerveStudy.getThreshold());//注入隐层阈值
@ -172,7 +186,7 @@ public class NerveManager {
NerveStudy nerveStudy = outStudyNerves.get(i);
outNerve.setThreshold(nerveStudy.getThreshold());
Map<Integer, Double> dendrites = outNerve.getDendrites();
Map<Integer, Double> studyDendrites = nerveStudy.getDendrites();
Map<Integer, Double> studyDendrites =unConversion(nerveStudy.getDendrites());
for (Map.Entry<Integer, Double> outEntry : dendrites.entrySet()) {
int key = outEntry.getKey();
dendrites.put(key, studyDendrites.get(key));

View File

@ -9,14 +9,14 @@ import java.util.Map;
* @date 3:36 下午 2020/1/8
*/
public class NerveStudy {
private Map<Integer, Double> dendrites = new HashMap<>();//上一层权重(需要取出)
private Map<String, Double> dendrites = new HashMap<>();//上一层权重(需要取出)
private double threshold;//此神经元的阈值需要取出
public Map<Integer, Double> getDendrites() {
public Map<String, Double> getDendrites() {
return dendrites;
}
public void setDendrites(Map<Integer, Double> dendrites) {
public void setDendrites(Map<String, Double> dendrites) {
this.dendrites = dendrites;
}

View File

@ -17,8 +17,8 @@ import java.util.Map;
public class ImageTest {
public static void main(String[] args) throws Exception {
//dish();
study();
dish();
//study();
}
public static void dish() throws Exception {//识别
@ -27,7 +27,7 @@ public class ImageTest {
config.setTypeNub(2);//设置类别数量
config.setBoxSize(125);//设置物体大小 单位像素
config.setPictureNumber(5);//设置每个种类训练图片数量
config.setPth(0.7);//设置可信概率只有超过可信概率阈值得出的结果才是可信的
config.setPth(0.55);//设置可信概率只有超过可信概率阈值得出的结果才是可信的
config.setShowLog(true);//输出学习时打印数据
Distinguish distinguish = new Distinguish(config);
distinguish.insertModel(JSONObject.parseObject(ModelData.DATA, Model.class));
@ -49,7 +49,7 @@ public class ImageTest {
config.setTypeNub(2);//设置类别数量
config.setBoxSize(125);//设置物体大小 单位像素 125*125 矩形
config.setPictureNumber(5);//设置每个种类训练图片数量
config.setPth(0.7);//设置可信概率只有超过可信概率阈值得出的结果才是可信的0-1
config.setPth(0.55);//设置可信概率只有超过可信概率阈值得出的结果才是可信的0-1
config.setShowLog(true);//输出学习时打印数据
Distinguish distinguish = new Distinguish(config);//识别类
distinguish.setBackGround(picture.getThreeMatrix("E:\\ls\\fp15\\back.jpg"));//塞入背景图片

File diff suppressed because one or more lines are too long