添加多种类干食切图

This commit is contained in:
lidapeng 2020-11-01 18:32:55 +08:00
parent 2254885c57
commit 121cfda3f6
12 changed files with 23 additions and 578 deletions

View File

@ -70,7 +70,7 @@ public class Convolution extends Frequency {
List<ThreeChannelMatrix> threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize); List<ThreeChannelMatrix> threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize);
for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) { for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) {
List<Double> feature = new ArrayList<>(); List<Double> feature = new ArrayList<>();
MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig); MeanClustering meanClustering = new MeanClustering(sqNub);
Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB(); Matrix matrixB = threeChannelMatrix.getMatrixB();
@ -82,7 +82,7 @@ public class Convolution extends Frequency {
meanClustering.setColor(color); meanClustering.setColor(color);
} }
} }
meanClustering.start(false); meanClustering.start();
List<RGBNorm> rgbNorms = meanClustering.getMatrices(); List<RGBNorm> rgbNorms = meanClustering.getMatrices();
Collections.sort(rgbNorms, rgbSort); Collections.sort(rgbNorms, rgbSort);
for (RGBNorm rgbNorm : rgbNorms) { for (RGBNorm rgbNorm : rgbNorms) {
@ -96,48 +96,10 @@ public class Convolution extends Frequency {
return features; return features;
} }
public List<Double> getCenterColor(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig,
int sqNub) throws Exception {
Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB();
MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
int maxX = matrixR.getX();
int maxY = matrixR.getY();
ColorFunction colorFunction = new ColorFunction(threeChannelMatrix);
int[] minBorder = new int[]{0, 0};
int[] maxBorder = new int[]{maxX - 1, maxY - 1};
//创建粒子群
PSO pso = new PSO(2, minBorder, maxBorder, 200, 200,
colorFunction, 0.2, 1, 0.5, true, 10, 1);
List<double[]> positions = pso.start();
for (int i = 0; i < positions.size(); i++) {
double[] parameter = positions.get(i);
//获取取样坐标
int x = (int) parameter[0];
int y = (int) parameter[1];
double[] rgb = new double[]{matrixR.getNumber(x, y), matrixG.getNumber(x, y),
matrixB.getNumber(x, y)};
meanClustering.setColor(rgb);
}
meanClustering.start(true);
List<RGBNorm> rgbNorms = meanClustering.getMatrices();
List<Double> features = new ArrayList<>();
for (int i = 0; i < sqNub; i++) {
double[] rgb = rgbNorms.get(i).getRgb();
for (int j = 0; j < rgb.length; j++) {
features.add(rgb[j]);
}
}
return features;
}
public List<Double> getCenterTexture(ThreeChannelMatrix threeChannelMatrix, int size, TempleConfig templeConfig public List<Double> getCenterTexture(ThreeChannelMatrix threeChannelMatrix, int size, TempleConfig templeConfig
, int sqNub) throws Exception { , int sqNub) throws Exception {
//MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig); //MeanClustering meanClustering = new MeanClustering(sqNub);
GMClustering meanClustering = new GMClustering(sqNub, templeConfig); GMClustering meanClustering = new GMClustering(sqNub);
Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB(); Matrix matrixB = threeChannelMatrix.getMatrixB();
@ -158,40 +120,7 @@ public class Convolution extends Frequency {
} }
} }
} }
// for (int i = 0; i <= xn - size; i += 2) { meanClustering.start();
// for (int j = 0; j <= yn - size; j += 2) {
// Matrix sonR = matrixR.getSonOfMatrix(i, j, size, size);
// Matrix sonG = matrixG.getSonOfMatrix(i, j, size, size);
// Matrix sonB = matrixB.getSonOfMatrix(i, j, size, size);
// Matrix sonRGB = matrixRGB.getSonOfMatrix(i, j, size, size);
// double[] h = new double[nub];
// double[] rgb = new double[nub * 3];
// for (int t = 0; t < size; t++) {
// for (int k = 0; k < size; k++) {
// int index = t * size + k;
// h[index] = sonRGB.getNumber(t, k);
// rgb[index] = sonR.getNumber(t, k) / 255;
// rgb[nub + index] = sonG.getNumber(t, k) / 255;
// rgb[twoNub + index] = sonB.getNumber(t, k) / 255;
// }
// }
// //900 200
// double dispersed = variance(h);
// if (dispersed < 900 && dispersed > 200) {
// for (int m = 0; m < nub; m++) {
// double[] color = new double[]{rgb[m], rgb[m + nub], rgb[m + twoNub]};
// meanClustering.setColor(color);
// }
// }
// }
// }
//List<double[]> list = meanClustering.start(true);//开始聚类
meanClustering.start(true);
// if (tag == 0) {//识别
// templeConfig.getFood().getkNerveManger().look(list);
// } else {//训练
// templeConfig.getFood().getkNerveManger().setFeature(tag, list);
// }
List<RGBNorm> rgbNorms = meanClustering.getMatrices(); List<RGBNorm> rgbNorms = meanClustering.getMatrices();
List<Double> features = new ArrayList<>(); List<Double> features = new ArrayList<>();
for (int i = 0; i < sqNub; i++) { for (int i = 0; i < sqNub; i++) {

View File

@ -38,7 +38,7 @@ public class CutFood {
mean.setColor(rgb); mean.setColor(rgb);
} }
} }
mean.start(true); mean.start();
} }
private double getAvg(Matrix matrix) throws Exception { private double getAvg(Matrix matrix) throws Exception {
@ -89,7 +89,7 @@ public class CutFood {
double regionSize = gmBody.getPixelNub() * s; double regionSize = gmBody.getPixelNub() * s;
int type = gmBody.getType(); int type = gmBody.getType();
if (type != 1) {//背景直接过滤 if (type != 1) {//背景直接过滤
int oneSize = meanMap.get(type).getRegionSize(); double oneSize = meanMap.get(type).getRegionSize();
if (regionSize > oneSize * 0.8) { if (regionSize > oneSize * 0.8) {
gmBodies2.add(gmBody); gmBodies2.add(gmBody);
} }
@ -98,7 +98,7 @@ public class CutFood {
for (GMBody gmBody : gmBodies2) { for (GMBody gmBody : gmBodies2) {
int type = gmBody.getType(); int type = gmBody.getType();
double regionSize = gmBody.getPixelNub() * s; double regionSize = gmBody.getPixelNub() * s;
int oneSize = meanMap.get(type).getRegionSize(); double oneSize = meanMap.get(type).getRegionSize();
double nub = regionSize / (double) oneSize; double nub = regionSize / (double) oneSize;
System.out.println("type==" + type + ",nub==" + nub + ",onSize==" + oneSize + ",gmNub==" System.out.println("type==" + type + ",nub==" + nub + ",onSize==" + oneSize + ",gmNub=="
+ gmBody.getPixelNub()); + gmBody.getPixelNub());
@ -138,8 +138,8 @@ public class CutFood {
Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixR = threeChannelMatrix.getMatrixR();
int x = matrixR.getX(); int x = matrixR.getX();
int y = matrixR.getY(); int y = matrixR.getY();
GMClustering mean = new GMClustering(templeConfig.getFeatureNub(), templeConfig); GMClustering mean = new GMClustering(templeConfig.getFeatureNub());
mean.setRegionSize(x * y); mean.setRegionSize(x * y * 0.8);
meanMap.put(type, mean); meanMap.put(type, mean);
mean(threeChannelMatrix, mean); mean(threeChannelMatrix, mean);
//记录非背景的单物体面积 //记录非背景的单物体面积

View File

@ -1,24 +0,0 @@
package org.wlld.imageRecognition;
import org.wlld.imageRecognition.border.DistBody;
import java.util.Comparator;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description
*/
public class DistSort implements Comparator<DistBody> {
@Override
public int compare(DistBody o1, DistBody o2) {
if (o1.getDist() < o2.getDist()) {
return -1;
} else if (o1.getDist() > o2.getDist()) {
return 1;
} else {
return 0;
}
}
}

View File

@ -9,24 +9,19 @@ public class MeanClustering {
private int length;//向量长度(模型需要返回) private int length;//向量长度(模型需要返回)
protected int speciesQuantity;//种类数量(模型需要返回) protected int speciesQuantity;//种类数量(模型需要返回)
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回) protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
private TempleConfig templeConfig;
private int sensoryNerveNub;//神经元个数
private List<MeanClustering> kList = new ArrayList<>();
public List<RGBNorm> getMatrices() { public List<RGBNorm> getMatrices() {
return matrices; return matrices;
} }
public MeanClustering(int speciesQuantity, TempleConfig templeConfig) throws Exception { public MeanClustering(int speciesQuantity) throws Exception {
this.speciesQuantity = speciesQuantity;//聚类的数量 this.speciesQuantity = speciesQuantity;//聚类的数量
this.templeConfig = templeConfig;
} }
public void setColor(double[] color) throws Exception { public void setColor(double[] color) throws Exception {
if (matrixList.size() == 0) { if (matrixList.size() == 0) {
matrixList.add(color); matrixList.add(color);
length = color.length; length = color.length;
sensoryNerveNub = templeConfig.getFeatureNub() * length;
} else { } else {
if (length == color.length) { if (length == color.length) {
matrixList.add(color); matrixList.add(color);
@ -90,76 +85,7 @@ public class MeanClustering {
return sigma; return sigma;
} }
private List<double[]> listK(List<double[]> listOne, int nub) { public void start() throws Exception {//开始聚类
int size = listOne.size();
int oneSize = size / nub;//几份取一份平均值
//System.out.println("oneSize==" + oneSize);
List<double[]> allList = new ArrayList<>();
for (int i = 0; i <= size - oneSize; i += oneSize) {
double[] avg = getListAvg(listOne.subList(i, i + oneSize));
allList.add(avg);
}
return allList;
}
private List<double[]> startBp() {
int times = 2000;
int index = 0;
List<double[]> features = new ArrayList<>();
List<List<double[]>> lists = new ArrayList<>();
for (int j = 0; j < matrices.size(); j++) {
List<double[]> listOne = matrices.get(j).getRgbs();
// List<double[]> list = listK(listOne, times);
//System.out.println(listOne.size());
List<double[]> list = listOne.subList(index, times + index);
lists.add(list);
}
for (int j = 0; j < times; j++) {
double[] feature = new double[sensoryNerveNub];
for (int i = 0; i < lists.size(); i++) {
double[] data = lists.get(i).get(j);
int len = data.length;
for (int k = 0; k < len; k++) {
feature[i * len + k] = data[k];
}
}
features.add(feature);
}
return features;
}
private List<double[]> startRegression() throws Exception {//开始聚类回归
for (int i = 0; i < matrices.size(); i++) {
List<double[]> list = matrices.get(i).getRgbs();
MeanClustering k = kList.get(i);
for (double[] rgb : list) {
k.setColor(rgb);
}
k.start(false);
}
//遍历子聚类
int times = 2000;
Random random = new Random();
List<double[]> features = new ArrayList<>();
for (int i = 0; i < times; i++) {
double[] feature = new double[sensoryNerveNub];
for (int k = 0; k < kList.size(); k++) {
MeanClustering mean = kList.get(k);
List<RGBNorm> rgbNorms = mean.getMatrices();
double[] rgb = rgbNorms.get(random.nextInt(rgbNorms.size())).getRgb();
int rgbLen = rgb.length;
for (int t = 0; t < rgbLen; t++) {
int index = k * rgbLen + t;
feature[index] = rgb[t];
}
}
//System.out.println(Arrays.toString(feature));
features.add(feature);
}
return features;
}
public void start(boolean isRegression) throws Exception {//开始聚类
if (matrixList.size() > 1) { if (matrixList.size() > 1) {
Random random = new Random(); Random random = new Random();
for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量 for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量
@ -182,10 +108,6 @@ public class MeanClustering {
} }
RGBSort rgbSort = new RGBSort(); RGBSort rgbSort = new RGBSort();
Collections.sort(matrices, rgbSort); Collections.sort(matrices, rgbSort);
// for (RGBNorm rgbNorm : matrices) {
// rgbNorm.finish();
// }
// return startBp();
} else { } else {
throw new Exception("matrixList number less than 2"); throw new Exception("matrixList number less than 2");
} }

View File

@ -1,33 +0,0 @@
package org.wlld.imageRecognition;
import org.wlld.imageRecognition.modelEntity.RegressionBody;
public class XYBody {
private double[] X;
private double[] Y;
private RegressionBody regressionBody;
public RegressionBody getRegressionBody() {
return regressionBody;
}
public void setRegressionBody(RegressionBody regressionBody) {
this.regressionBody = regressionBody;
}
public double[] getX() {
return X;
}
public void setX(double[] x) {
X = x;
}
public double[] getY() {
return Y;
}
public void setY(double[] y) {
Y = y;
}
}

View File

@ -12,18 +12,18 @@ import org.wlld.imageRecognition.TempleConfig;
* @Description * @Description
*/ */
public class GMClustering extends MeanClustering { public class GMClustering extends MeanClustering {
private int regionSize;//单区域面积 private double regionSize;//单区域面积
public int getRegionSize() { public double getRegionSize() {
return regionSize; return regionSize;
} }
public void setRegionSize(int regionSize) { public void setRegionSize(double regionSize) {
this.regionSize = regionSize; this.regionSize = regionSize;
} }
public GMClustering(int speciesQuantity, TempleConfig templeConfig) throws Exception { public GMClustering(int speciesQuantity) throws Exception {
super(speciesQuantity, templeConfig); super(speciesQuantity);
} }
public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度 public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度
@ -35,8 +35,8 @@ public class GMClustering extends MeanClustering {
} }
@Override @Override
public void start(boolean isRegression) throws Exception { public void start() throws Exception {
super.start(isRegression); super.start();
for (RGBNorm rgbNorm : matrices) {//高斯系数初始化 for (RGBNorm rgbNorm : matrices) {//高斯系数初始化
rgbNorm.gm(); rgbNorm.gm();
} }

View File

@ -2,13 +2,9 @@ package org.wlld.imageRecognition.modelEntity;
import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.Matrix;
import org.wlld.imageRecognition.TempleConfig; import org.wlld.imageRecognition.TempleConfig;
import org.wlld.imageRecognition.border.Knn;
import org.wlld.imageRecognition.segmentation.DimensionMappingStudy; import org.wlld.imageRecognition.segmentation.DimensionMappingStudy;
import java.util.List;
import java.util.Set;
/** /**
* @param * @param
* @DATA * @DATA
@ -17,25 +13,14 @@ import java.util.Set;
*/ */
public class DeepMappingBody { public class DeepMappingBody {
private DimensionMappingStudy dimensionAll; private DimensionMappingStudy dimensionAll;
private TempleConfig templeConfig;
private List<KeyMapping> mappingList;
public DeepMappingBody(TempleConfig templeConfig) throws Exception { public DeepMappingBody(TempleConfig templeConfig) throws Exception {
this.templeConfig = templeConfig;
dimensionAll = new DimensionMappingStudy(templeConfig, true); dimensionAll = new DimensionMappingStudy(templeConfig, true);
mappingList = dimensionAll.start(); dimensionAll.start();
} }
public int getType(Matrix feature) throws Exception { public int getType(Matrix feature) throws Exception {
int type = dimensionAll.getType(feature); int type = dimensionAll.getType(feature);
for (KeyMapping keyMapping : mappingList) {
Set<Integer> region = keyMapping.getKeys();
if (region.contains(type)) {
DimensionMappingStudy mapping = keyMapping.getDimensionMapping();
type = mapping.getType(feature);
break;
}
}
return type; return type;
} }
} }

View File

@ -274,9 +274,8 @@ public class DimensionMappingStudy {
myKnn.setFeatureMap(featureMapping(featureMap, mappingSigma)); myKnn.setFeatureMap(featureMapping(featureMap, mappingSigma));
} }
public List<KeyMapping> start() throws Exception { public void start() throws Exception {
mappingStart(); mappingStart();
return selfTest(1);
} }
private Map<Integer, List<Matrix>> featureMapping(Map<Integer, List<Matrix>> featureMap, double[] mapping) throws Exception { private Map<Integer, List<Matrix>> featureMapping(Map<Integer, List<Matrix>> featureMap, double[] mapping) throws Exception {

View File

@ -1,112 +0,0 @@
package org.wlld.imageRecognition.segmentation;
import org.wlld.config.RZ;
import org.wlld.function.Sigmod;
import org.wlld.function.Tanh;
import org.wlld.imageRecognition.modelEntity.RgbBack;
import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveEntity.SensoryNerve;
import java.awt.image.Kernel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description
*/
public class KNerveManger {
private Map<Integer, List<double[]>> featureMap = new HashMap<>();
private int sensoryNerveNub;//输出神经元个数
private int speciesNub;//种类数
private NerveManager nerveManager;
private int times;
private RgbBack rgbBack = new RgbBack();
public KNerveManger(int sensoryNerveNub, int speciesNub, int times) throws Exception {
this.sensoryNerveNub = sensoryNerveNub;
this.speciesNub = speciesNub;
this.times = times;
nerveManager = new NerveManager(sensoryNerveNub, 24, speciesNub,
1, new Tanh(),//0.008 l1 0.02
false, false, 0.008, RZ.L1, 0.02);
nerveManager.init(true, false, true, true);
}
private Map<Integer, Double> createTag(int tag) {//创建一个标注
Map<Integer, Double> tagging = new HashMap<>();
tagging.put(tag, 1.0);
return tagging;
}
public void look(List<double[]> data) throws Exception {
int size = data.size();
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < size; i++) {
rgbBack.clear();
post(data.get(i), null, false);
int type = rgbBack.getId();
if (map.containsKey(type)) {
map.put(type, map.get(type) + 1);
} else {
map.put(type, 1);
}
}
double max = 0;
int type = 0;
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
int nub = entry.getValue();
if (nub > max) {
max = nub;
type = entry.getKey();
}
}
double point = max / size;
System.out.println("类型是:" + type + ",总票数:" + size + ",得票率:" + point);
System.out.println("=================================完成");
}
public void startStudy() throws Exception {
for (int j = 0; j < 2; j++) {
for (int i = 0; i < times; i++) {
for (Map.Entry<Integer, List<double[]>> entry : featureMap.entrySet()) {
int type = entry.getKey();
System.out.println("=============================" + type);
Map<Integer, Double> tag = createTag(type);//标注
double[] feature = entry.getValue().get(i);//数据
post(feature, tag, true);
}
}
}
for (Map.Entry<Integer, List<double[]>> entry : featureMap.entrySet()) {
int type = entry.getKey();
System.out.println("=============================" + type);
List<double[]> list = entry.getValue();
look(list);
}
}
private void post(double[] data, Map<Integer, Double> tagging, boolean isStudy) throws Exception {
List<SensoryNerve> sensoryNerveList = nerveManager.getSensoryNerves();
int size = sensoryNerveList.size();
for (int i = 0; i < size; i++) {
sensoryNerveList.get(i).postMessage(1, data[i], isStudy, tagging, rgbBack);
}
}
public void setFeature(int type, List<double[]> feature) {
if (type > 0) {
if (featureMap.containsKey(type)) {
featureMap.get(type).addAll(feature);
} else {
featureMap.put(type, feature);
}
}
}
}

View File

@ -1,206 +0,0 @@
package org.wlld.imageRecognition.segmentation;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.imageRecognition.DistSort;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.imageRecognition.border.DistBody;
import java.util.*;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description 修正映射
*/
public class RegionMapping {
private Map<Integer, List<Matrix>> featureMap;
private double studyPoint = 0.01;
private Map<Integer, Matrix> powerMatrixMap = new HashMap<>();//映射层的权重矩阵
private int times = 200;
public RegionMapping(TempleConfig templeConfig) {
this.featureMap = templeConfig.getKnn().getFeatureMap();
}
public void detection() throws Exception {//对模型进行检查,先对当前模型最相近的三个模型进行检出
Map<Integer, List<DistBody>> map = new HashMap<>();
DistSort distSort = new DistSort();
for (Map.Entry<Integer, List<Matrix>> entry : featureMap.entrySet()) {
int key = entry.getKey();
List<Matrix> myFeature = entry.getValue();
//记录 哪个类别 每个类别距离这个类别的最小距离
List<DistBody> distBodies = new ArrayList<>();
map.put(key, distBodies);
for (Map.Entry<Integer, List<Matrix>> entrySon : featureMap.entrySet()) {
int type = entrySon.getKey();
if (key != type) {//当前其余类别
DistBody distBody = new DistBody();
List<Matrix> features = entrySon.getValue();//其余类别的所有特征
double minError = -1;
for (Matrix myMatrix : myFeature) {//待验证类别所有特征
for (Matrix otherFeature : features) {
double dist = MatrixOperation.getEDistByMatrix(myMatrix, otherFeature);
if (minError < 0 || dist < minError) {
minError = dist;
}
}
}
//当前类别遍历结束
distBody.setId(type);
distBody.setDist(minError);
distBodies.add(distBody);
}
}
}
//进行排序
for (Map.Entry<Integer, List<DistBody>> entry : map.entrySet()) {
List<DistBody> list = entry.getValue();
Collections.sort(list, distSort);
}
adjustWeight(map, 3);
test(12, map, 3);
// int testType = 5;
// List<DistBody> testBody = map.get(testType);
// Matrix matrix = powerMatrixMap.get(testType);
// System.out.println(matrix.getString());
}
private Matrix mapping(Matrix feature, Matrix mapping) throws Exception {
int y = feature.getY();
if (y == mapping.getY()) {
Matrix matrix = new Matrix(1, y);
for (int i = 0; i < y; i++) {
double nub = feature.getNumber(0, i) * mapping.getNumber(0, i);
matrix.setNub(0, i, nub);
}
return matrix;
} else {
throw new Exception("matrix is not equals");
}
}
private void test(int key, Map<Integer, List<DistBody>> map, int nub) throws Exception {
Matrix matrixMapping = powerMatrixMap.get(key);//映射
List<Matrix> rightFeature = featureMap.get(key);//正确的特征
System.out.println("正确特征====================");
for (Matrix matrix : rightFeature) {
System.out.println(matrix.getString());
}
int y = rightFeature.get(0).getY();
Matrix featureMatrix = getFeatureAvg(rightFeature, y);//均值
List<DistBody> distBodies = map.get(key);
List<Matrix> wrongFeature = new ArrayList<>();
for (int i = 0; i < nub; i++) {
DistBody distBody = distBodies.get(i);
int id = distBody.getId();//相近的特征id
wrongFeature.addAll(featureMap.get(id));
}
System.out.println("错误特征=============");
for (Matrix matrix : wrongFeature) {
System.out.println(matrix.getString());
}
//做验证
double minOtherDist = -1;//其余特征最小距离
for (Matrix wrongMatrix : wrongFeature) {
Matrix wrongMapping = mapping(wrongMatrix, matrixMapping);
double dist = MatrixOperation.getEDistByMatrix(featureMatrix, wrongMapping);
System.out.println("异类距离:" + dist);
if (minOtherDist < 0 || dist < minOtherDist) {
minOtherDist = dist;
}
}
System.out.println("异类最小距离:" + minOtherDist);
for (Matrix rightMatrix : rightFeature) {
Matrix rightMapping = mapping(rightMatrix, matrixMapping);
double dist = MatrixOperation.getEDistByMatrix(featureMatrix, rightMapping);
System.out.println("同类距离:" + dist);
}
}
private void adjustWeight(Map<Integer, List<DistBody>> map, int nub) throws Exception {//进行权重调整
for (Map.Entry<Integer, List<DistBody>> entry : map.entrySet()) {
int key = entry.getKey();//当前类别的id
List<Matrix> rightFeature = featureMap.get(key);//正确的特征
List<DistBody> distBodies = entry.getValue();
List<Matrix> wrongFeature = new ArrayList<>();
for (int i = 0; i < nub; i++) {
DistBody distBody = distBodies.get(i);
int id = distBody.getId();//相近的特征id
wrongFeature.addAll(featureMap.get(id));
}
//对每个分类的特征图进行权重调整
updatePower(wrongFeature, rightFeature, key);
}
}
//调整权重矩阵
private void updatePower(List<Matrix> wrongFeature, List<Matrix> rightFeature, int key) throws Exception {
int size = wrongFeature.size();
int y = rightFeature.get(0).getY();
//特征均值向量
Matrix featureMatrix = getFeatureAvg(rightFeature, y);
Matrix powerMatrix = new Matrix(1, y);
powerMatrixMap.put(key, powerMatrix);
Random random = new Random();
for (int j = 0; j < times; j++) {
for (int i = 0; i < size; i++) {
Matrix feature = rightFeature.get(random.nextInt(rightFeature.size()));
Matrix noFeature = wrongFeature.get(i);
powerDeflection(feature, featureMatrix, true, powerMatrix);
powerDeflection(noFeature, featureMatrix, false, powerMatrix);
}
}
end();
}
private void end() throws Exception {
for (Map.Entry<Integer, Matrix> entry : powerMatrixMap.entrySet()) {
Matrix powerMatrix = entry.getValue();
int y = powerMatrix.getY();
double min = 0;
for (int i = 0; i < y; i++) {
double nub = powerMatrix.getNumber(0, i);
if (nub < min) {
min = nub;
}
}
//获取最小值完毕
for (int i = 0; i < y; i++) {
double nub = powerMatrix.getNumber(0, i);
powerMatrix.setNub(0, i, nub - min);
}
}
}
private void powerDeflection(Matrix matrix1, Matrix matrix2, boolean polymerization,
Matrix powerMatrix) throws Exception {
int y = matrix1.getY();
for (int i = 0; i < y; i++) {
double sub = Math.abs(matrix1.getNumber(0, i) - matrix2.getNumber(0, i))
* studyPoint;
double power = powerMatrix.getNumber(0, i);//当前矩阵中的权值
if (polymerization) {//同类聚合 聚合是减
power = power - sub;
} else {//异类离散
power = power + sub;
}
powerMatrix.setNub(0, i, power);
}
}
private Matrix getFeatureAvg(List<Matrix> rightFeature, int size) throws Exception {//求特征均值
Matrix feature = new Matrix(1, size);
int nub = rightFeature.size();
for (int i = 0; i < nub; i++) {
Matrix matrix = rightFeature.get(i);
for (int j = 0; j < size; j++) {
double sigma = matrix.getNumber(0, j) + feature.getNumber(0, j);
feature.setNub(0, j, sigma);
}
}
MatrixOperation.mathDiv(feature, nub);
return feature;
}
}

View File

@ -1,7 +1,6 @@
package org.wlld.param; package org.wlld.param;
import org.wlld.imageRecognition.modelEntity.DeepMappingBody; import org.wlld.imageRecognition.modelEntity.DeepMappingBody;
import org.wlld.imageRecognition.segmentation.KNerveManger;
import org.wlld.imageRecognition.segmentation.RgbRegression; import org.wlld.imageRecognition.segmentation.RgbRegression;
import java.util.ArrayList; import java.util.ArrayList;
@ -24,7 +23,6 @@ public class Food {
private int regionSize = 5;//纹理区域大小 private int regionSize = 5;//纹理区域大小
private int step = 1;//特征取样步长 private int step = 1;//特征取样步长
private int speciesNub = 24;//种类数 private int speciesNub = 24;//种类数
private KNerveManger kNerveManger;
private DeepMappingBody deepMappingBody;//特征映射 private DeepMappingBody deepMappingBody;//特征映射
public DeepMappingBody getDeepMappingBody() { public DeepMappingBody getDeepMappingBody() {
@ -35,14 +33,6 @@ public class Food {
this.deepMappingBody = deepMappingBody; this.deepMappingBody = deepMappingBody;
} }
public KNerveManger getkNerveManger() {
return kNerveManger;
}
public void setkNerveManger(KNerveManger kNerveManger) {
this.kNerveManger = kNerveManger;
}
public int getSpeciesNub() { public int getSpeciesNub() {
return speciesNub; return speciesNub;
} }

View File

@ -71,8 +71,6 @@ public class FoodTest {
//菜品识别实体类 //菜品识别实体类
food.setShrink(0);//缩紧像素 food.setShrink(0);//缩紧像素
food.setRegionSize(2); food.setRegionSize(2);
KNerveManger kNerveManger = new KNerveManger(12, 24, 6000);
food.setkNerveManger(kNerveManger);
food.setRowMark(0.05);//0.12 food.setRowMark(0.05);//0.12
food.setColumnMark(0.05);//0.25 food.setColumnMark(0.05);//0.25
food.setRegressionNub(50000); food.setRegressionNub(50000);
@ -128,12 +126,9 @@ public class FoodTest {
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix(c); ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix(c);
study(threeChannelMatrix1, templeConfig, specificationsList, cutFood, 2); study(threeChannelMatrix1, templeConfig, specificationsList, cutFood, 2);
study(threeChannelMatrix2, templeConfig, specificationsList, cutFood, 3); study(threeChannelMatrix2, templeConfig, specificationsList, cutFood, 3);
/// //
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix(g); ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix(e);
long time1 = System.currentTimeMillis();
look(threeChannelMatrix3, templeConfig, specificationsList, cutFood); look(threeChannelMatrix3, templeConfig, specificationsList, cutFood);
long end = System.currentTimeMillis() - time1;
System.out.println("time:" + end);
} }
private static void look(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig, List<Specifications> specifications, private static void look(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig, List<Specifications> specifications,