mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 03:38:08 +08:00
添加多种类干食切图
This commit is contained in:
parent
2254885c57
commit
121cfda3f6
@ -70,7 +70,7 @@ public class Convolution extends Frequency {
|
||||
List<ThreeChannelMatrix> threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize);
|
||||
for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) {
|
||||
List<Double> feature = new ArrayList<>();
|
||||
MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
|
||||
MeanClustering meanClustering = new MeanClustering(sqNub);
|
||||
Matrix matrixR = threeChannelMatrix.getMatrixR();
|
||||
Matrix matrixG = threeChannelMatrix.getMatrixG();
|
||||
Matrix matrixB = threeChannelMatrix.getMatrixB();
|
||||
@ -82,7 +82,7 @@ public class Convolution extends Frequency {
|
||||
meanClustering.setColor(color);
|
||||
}
|
||||
}
|
||||
meanClustering.start(false);
|
||||
meanClustering.start();
|
||||
List<RGBNorm> rgbNorms = meanClustering.getMatrices();
|
||||
Collections.sort(rgbNorms, rgbSort);
|
||||
for (RGBNorm rgbNorm : rgbNorms) {
|
||||
@ -96,48 +96,10 @@ public class Convolution extends Frequency {
|
||||
return features;
|
||||
}
|
||||
|
||||
|
||||
public List<Double> getCenterColor(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig,
|
||||
int sqNub) throws Exception {
|
||||
Matrix matrixR = threeChannelMatrix.getMatrixR();
|
||||
Matrix matrixG = threeChannelMatrix.getMatrixG();
|
||||
Matrix matrixB = threeChannelMatrix.getMatrixB();
|
||||
MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
|
||||
int maxX = matrixR.getX();
|
||||
int maxY = matrixR.getY();
|
||||
ColorFunction colorFunction = new ColorFunction(threeChannelMatrix);
|
||||
int[] minBorder = new int[]{0, 0};
|
||||
int[] maxBorder = new int[]{maxX - 1, maxY - 1};
|
||||
//创建粒子群
|
||||
PSO pso = new PSO(2, minBorder, maxBorder, 200, 200,
|
||||
colorFunction, 0.2, 1, 0.5, true, 10, 1);
|
||||
List<double[]> positions = pso.start();
|
||||
for (int i = 0; i < positions.size(); i++) {
|
||||
double[] parameter = positions.get(i);
|
||||
//获取取样坐标
|
||||
int x = (int) parameter[0];
|
||||
int y = (int) parameter[1];
|
||||
double[] rgb = new double[]{matrixR.getNumber(x, y), matrixG.getNumber(x, y),
|
||||
matrixB.getNumber(x, y)};
|
||||
meanClustering.setColor(rgb);
|
||||
}
|
||||
meanClustering.start(true);
|
||||
List<RGBNorm> rgbNorms = meanClustering.getMatrices();
|
||||
List<Double> features = new ArrayList<>();
|
||||
for (int i = 0; i < sqNub; i++) {
|
||||
double[] rgb = rgbNorms.get(i).getRgb();
|
||||
for (int j = 0; j < rgb.length; j++) {
|
||||
features.add(rgb[j]);
|
||||
}
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
|
||||
public List<Double> getCenterTexture(ThreeChannelMatrix threeChannelMatrix, int size, TempleConfig templeConfig
|
||||
, int sqNub) throws Exception {
|
||||
//MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
|
||||
GMClustering meanClustering = new GMClustering(sqNub, templeConfig);
|
||||
//MeanClustering meanClustering = new MeanClustering(sqNub);
|
||||
GMClustering meanClustering = new GMClustering(sqNub);
|
||||
Matrix matrixR = threeChannelMatrix.getMatrixR();
|
||||
Matrix matrixG = threeChannelMatrix.getMatrixG();
|
||||
Matrix matrixB = threeChannelMatrix.getMatrixB();
|
||||
@ -158,40 +120,7 @@ public class Convolution extends Frequency {
|
||||
}
|
||||
}
|
||||
}
|
||||
// for (int i = 0; i <= xn - size; i += 2) {
|
||||
// for (int j = 0; j <= yn - size; j += 2) {
|
||||
// Matrix sonR = matrixR.getSonOfMatrix(i, j, size, size);
|
||||
// Matrix sonG = matrixG.getSonOfMatrix(i, j, size, size);
|
||||
// Matrix sonB = matrixB.getSonOfMatrix(i, j, size, size);
|
||||
// Matrix sonRGB = matrixRGB.getSonOfMatrix(i, j, size, size);
|
||||
// double[] h = new double[nub];
|
||||
// double[] rgb = new double[nub * 3];
|
||||
// for (int t = 0; t < size; t++) {
|
||||
// for (int k = 0; k < size; k++) {
|
||||
// int index = t * size + k;
|
||||
// h[index] = sonRGB.getNumber(t, k);
|
||||
// rgb[index] = sonR.getNumber(t, k) / 255;
|
||||
// rgb[nub + index] = sonG.getNumber(t, k) / 255;
|
||||
// rgb[twoNub + index] = sonB.getNumber(t, k) / 255;
|
||||
// }
|
||||
// }
|
||||
// //900 200
|
||||
// double dispersed = variance(h);
|
||||
// if (dispersed < 900 && dispersed > 200) {
|
||||
// for (int m = 0; m < nub; m++) {
|
||||
// double[] color = new double[]{rgb[m], rgb[m + nub], rgb[m + twoNub]};
|
||||
// meanClustering.setColor(color);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//List<double[]> list = meanClustering.start(true);//开始聚类
|
||||
meanClustering.start(true);
|
||||
// if (tag == 0) {//识别
|
||||
// templeConfig.getFood().getkNerveManger().look(list);
|
||||
// } else {//训练
|
||||
// templeConfig.getFood().getkNerveManger().setFeature(tag, list);
|
||||
// }
|
||||
meanClustering.start();
|
||||
List<RGBNorm> rgbNorms = meanClustering.getMatrices();
|
||||
List<Double> features = new ArrayList<>();
|
||||
for (int i = 0; i < sqNub; i++) {
|
||||
|
@ -38,7 +38,7 @@ public class CutFood {
|
||||
mean.setColor(rgb);
|
||||
}
|
||||
}
|
||||
mean.start(true);
|
||||
mean.start();
|
||||
}
|
||||
|
||||
private double getAvg(Matrix matrix) throws Exception {
|
||||
@ -89,7 +89,7 @@ public class CutFood {
|
||||
double regionSize = gmBody.getPixelNub() * s;
|
||||
int type = gmBody.getType();
|
||||
if (type != 1) {//背景直接过滤
|
||||
int oneSize = meanMap.get(type).getRegionSize();
|
||||
double oneSize = meanMap.get(type).getRegionSize();
|
||||
if (regionSize > oneSize * 0.8) {
|
||||
gmBodies2.add(gmBody);
|
||||
}
|
||||
@ -98,7 +98,7 @@ public class CutFood {
|
||||
for (GMBody gmBody : gmBodies2) {
|
||||
int type = gmBody.getType();
|
||||
double regionSize = gmBody.getPixelNub() * s;
|
||||
int oneSize = meanMap.get(type).getRegionSize();
|
||||
double oneSize = meanMap.get(type).getRegionSize();
|
||||
double nub = regionSize / (double) oneSize;
|
||||
System.out.println("type==" + type + ",nub==" + nub + ",onSize==" + oneSize + ",gmNub=="
|
||||
+ gmBody.getPixelNub());
|
||||
@ -138,8 +138,8 @@ public class CutFood {
|
||||
Matrix matrixR = threeChannelMatrix.getMatrixR();
|
||||
int x = matrixR.getX();
|
||||
int y = matrixR.getY();
|
||||
GMClustering mean = new GMClustering(templeConfig.getFeatureNub(), templeConfig);
|
||||
mean.setRegionSize(x * y);
|
||||
GMClustering mean = new GMClustering(templeConfig.getFeatureNub());
|
||||
mean.setRegionSize(x * y * 0.8);
|
||||
meanMap.put(type, mean);
|
||||
mean(threeChannelMatrix, mean);
|
||||
//记录非背景的单物体面积
|
||||
|
@ -1,24 +0,0 @@
|
||||
package org.wlld.imageRecognition;
|
||||
|
||||
import org.wlld.imageRecognition.border.DistBody;
|
||||
|
||||
import java.util.Comparator;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
* @Author LiDaPeng
|
||||
* @Description
|
||||
*/
|
||||
public class DistSort implements Comparator<DistBody> {
|
||||
@Override
|
||||
public int compare(DistBody o1, DistBody o2) {
|
||||
if (o1.getDist() < o2.getDist()) {
|
||||
return -1;
|
||||
} else if (o1.getDist() > o2.getDist()) {
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
@ -9,24 +9,19 @@ public class MeanClustering {
|
||||
private int length;//向量长度(模型需要返回)
|
||||
protected int speciesQuantity;//种类数量(模型需要返回)
|
||||
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
|
||||
private TempleConfig templeConfig;
|
||||
private int sensoryNerveNub;//神经元个数
|
||||
private List<MeanClustering> kList = new ArrayList<>();
|
||||
|
||||
public List<RGBNorm> getMatrices() {
|
||||
return matrices;
|
||||
}
|
||||
|
||||
public MeanClustering(int speciesQuantity, TempleConfig templeConfig) throws Exception {
|
||||
public MeanClustering(int speciesQuantity) throws Exception {
|
||||
this.speciesQuantity = speciesQuantity;//聚类的数量
|
||||
this.templeConfig = templeConfig;
|
||||
}
|
||||
|
||||
public void setColor(double[] color) throws Exception {
|
||||
if (matrixList.size() == 0) {
|
||||
matrixList.add(color);
|
||||
length = color.length;
|
||||
sensoryNerveNub = templeConfig.getFeatureNub() * length;
|
||||
} else {
|
||||
if (length == color.length) {
|
||||
matrixList.add(color);
|
||||
@ -90,76 +85,7 @@ public class MeanClustering {
|
||||
return sigma;
|
||||
}
|
||||
|
||||
private List<double[]> listK(List<double[]> listOne, int nub) {
|
||||
int size = listOne.size();
|
||||
int oneSize = size / nub;//几份取一份平均值
|
||||
//System.out.println("oneSize==" + oneSize);
|
||||
List<double[]> allList = new ArrayList<>();
|
||||
for (int i = 0; i <= size - oneSize; i += oneSize) {
|
||||
double[] avg = getListAvg(listOne.subList(i, i + oneSize));
|
||||
allList.add(avg);
|
||||
}
|
||||
return allList;
|
||||
}
|
||||
|
||||
private List<double[]> startBp() {
|
||||
int times = 2000;
|
||||
int index = 0;
|
||||
List<double[]> features = new ArrayList<>();
|
||||
List<List<double[]>> lists = new ArrayList<>();
|
||||
for (int j = 0; j < matrices.size(); j++) {
|
||||
List<double[]> listOne = matrices.get(j).getRgbs();
|
||||
// List<double[]> list = listK(listOne, times);
|
||||
//System.out.println(listOne.size());
|
||||
List<double[]> list = listOne.subList(index, times + index);
|
||||
lists.add(list);
|
||||
}
|
||||
for (int j = 0; j < times; j++) {
|
||||
double[] feature = new double[sensoryNerveNub];
|
||||
for (int i = 0; i < lists.size(); i++) {
|
||||
double[] data = lists.get(i).get(j);
|
||||
int len = data.length;
|
||||
for (int k = 0; k < len; k++) {
|
||||
feature[i * len + k] = data[k];
|
||||
}
|
||||
}
|
||||
features.add(feature);
|
||||
}
|
||||
return features;
|
||||
}
|
||||
|
||||
private List<double[]> startRegression() throws Exception {//开始聚类回归
|
||||
for (int i = 0; i < matrices.size(); i++) {
|
||||
List<double[]> list = matrices.get(i).getRgbs();
|
||||
MeanClustering k = kList.get(i);
|
||||
for (double[] rgb : list) {
|
||||
k.setColor(rgb);
|
||||
}
|
||||
k.start(false);
|
||||
}
|
||||
//遍历子聚类
|
||||
int times = 2000;
|
||||
Random random = new Random();
|
||||
List<double[]> features = new ArrayList<>();
|
||||
for (int i = 0; i < times; i++) {
|
||||
double[] feature = new double[sensoryNerveNub];
|
||||
for (int k = 0; k < kList.size(); k++) {
|
||||
MeanClustering mean = kList.get(k);
|
||||
List<RGBNorm> rgbNorms = mean.getMatrices();
|
||||
double[] rgb = rgbNorms.get(random.nextInt(rgbNorms.size())).getRgb();
|
||||
int rgbLen = rgb.length;
|
||||
for (int t = 0; t < rgbLen; t++) {
|
||||
int index = k * rgbLen + t;
|
||||
feature[index] = rgb[t];
|
||||
}
|
||||
}
|
||||
//System.out.println(Arrays.toString(feature));
|
||||
features.add(feature);
|
||||
}
|
||||
return features;
|
||||
}
|
||||
|
||||
public void start(boolean isRegression) throws Exception {//开始聚类
|
||||
public void start() throws Exception {//开始聚类
|
||||
if (matrixList.size() > 1) {
|
||||
Random random = new Random();
|
||||
for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量
|
||||
@ -182,10 +108,6 @@ public class MeanClustering {
|
||||
}
|
||||
RGBSort rgbSort = new RGBSort();
|
||||
Collections.sort(matrices, rgbSort);
|
||||
// for (RGBNorm rgbNorm : matrices) {
|
||||
// rgbNorm.finish();
|
||||
// }
|
||||
// return startBp();
|
||||
} else {
|
||||
throw new Exception("matrixList number less than 2");
|
||||
}
|
||||
|
@ -1,33 +0,0 @@
|
||||
package org.wlld.imageRecognition;
|
||||
|
||||
import org.wlld.imageRecognition.modelEntity.RegressionBody;
|
||||
|
||||
public class XYBody {
|
||||
private double[] X;
|
||||
private double[] Y;
|
||||
private RegressionBody regressionBody;
|
||||
|
||||
public RegressionBody getRegressionBody() {
|
||||
return regressionBody;
|
||||
}
|
||||
|
||||
public void setRegressionBody(RegressionBody regressionBody) {
|
||||
this.regressionBody = regressionBody;
|
||||
}
|
||||
|
||||
public double[] getX() {
|
||||
return X;
|
||||
}
|
||||
|
||||
public void setX(double[] x) {
|
||||
X = x;
|
||||
}
|
||||
|
||||
public double[] getY() {
|
||||
return Y;
|
||||
}
|
||||
|
||||
public void setY(double[] y) {
|
||||
Y = y;
|
||||
}
|
||||
}
|
@ -12,18 +12,18 @@ import org.wlld.imageRecognition.TempleConfig;
|
||||
* @Description
|
||||
*/
|
||||
public class GMClustering extends MeanClustering {
|
||||
private int regionSize;//单区域面积
|
||||
private double regionSize;//单区域面积
|
||||
|
||||
public int getRegionSize() {
|
||||
public double getRegionSize() {
|
||||
return regionSize;
|
||||
}
|
||||
|
||||
public void setRegionSize(int regionSize) {
|
||||
public void setRegionSize(double regionSize) {
|
||||
this.regionSize = regionSize;
|
||||
}
|
||||
|
||||
public GMClustering(int speciesQuantity, TempleConfig templeConfig) throws Exception {
|
||||
super(speciesQuantity, templeConfig);
|
||||
public GMClustering(int speciesQuantity) throws Exception {
|
||||
super(speciesQuantity);
|
||||
}
|
||||
|
||||
public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度
|
||||
@ -35,8 +35,8 @@ public class GMClustering extends MeanClustering {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(boolean isRegression) throws Exception {
|
||||
super.start(isRegression);
|
||||
public void start() throws Exception {
|
||||
super.start();
|
||||
for (RGBNorm rgbNorm : matrices) {//高斯系数初始化
|
||||
rgbNorm.gm();
|
||||
}
|
||||
|
@ -2,13 +2,9 @@ package org.wlld.imageRecognition.modelEntity;
|
||||
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.imageRecognition.TempleConfig;
|
||||
import org.wlld.imageRecognition.border.Knn;
|
||||
import org.wlld.imageRecognition.segmentation.DimensionMappingStudy;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
@ -17,25 +13,14 @@ import java.util.Set;
|
||||
*/
|
||||
public class DeepMappingBody {
|
||||
private DimensionMappingStudy dimensionAll;
|
||||
private TempleConfig templeConfig;
|
||||
private List<KeyMapping> mappingList;
|
||||
|
||||
public DeepMappingBody(TempleConfig templeConfig) throws Exception {
|
||||
this.templeConfig = templeConfig;
|
||||
dimensionAll = new DimensionMappingStudy(templeConfig, true);
|
||||
mappingList = dimensionAll.start();
|
||||
dimensionAll.start();
|
||||
}
|
||||
|
||||
public int getType(Matrix feature) throws Exception {
|
||||
int type = dimensionAll.getType(feature);
|
||||
for (KeyMapping keyMapping : mappingList) {
|
||||
Set<Integer> region = keyMapping.getKeys();
|
||||
if (region.contains(type)) {
|
||||
DimensionMappingStudy mapping = keyMapping.getDimensionMapping();
|
||||
type = mapping.getType(feature);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
}
|
||||
|
@ -274,9 +274,8 @@ public class DimensionMappingStudy {
|
||||
myKnn.setFeatureMap(featureMapping(featureMap, mappingSigma));
|
||||
}
|
||||
|
||||
public List<KeyMapping> start() throws Exception {
|
||||
public void start() throws Exception {
|
||||
mappingStart();
|
||||
return selfTest(1);
|
||||
}
|
||||
|
||||
private Map<Integer, List<Matrix>> featureMapping(Map<Integer, List<Matrix>> featureMap, double[] mapping) throws Exception {
|
||||
|
@ -1,112 +0,0 @@
|
||||
package org.wlld.imageRecognition.segmentation;
|
||||
|
||||
import org.wlld.config.RZ;
|
||||
import org.wlld.function.Sigmod;
|
||||
import org.wlld.function.Tanh;
|
||||
import org.wlld.imageRecognition.modelEntity.RgbBack;
|
||||
import org.wlld.nerveCenter.NerveManager;
|
||||
import org.wlld.nerveEntity.SensoryNerve;
|
||||
|
||||
import java.awt.image.Kernel;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
* @Author LiDaPeng
|
||||
* @Description
|
||||
*/
|
||||
public class KNerveManger {
|
||||
private Map<Integer, List<double[]>> featureMap = new HashMap<>();
|
||||
private int sensoryNerveNub;//输出神经元个数
|
||||
private int speciesNub;//种类数
|
||||
private NerveManager nerveManager;
|
||||
private int times;
|
||||
private RgbBack rgbBack = new RgbBack();
|
||||
|
||||
public KNerveManger(int sensoryNerveNub, int speciesNub, int times) throws Exception {
|
||||
this.sensoryNerveNub = sensoryNerveNub;
|
||||
this.speciesNub = speciesNub;
|
||||
this.times = times;
|
||||
nerveManager = new NerveManager(sensoryNerveNub, 24, speciesNub,
|
||||
1, new Tanh(),//0.008 l1 0.02
|
||||
false, false, 0.008, RZ.L1, 0.02);
|
||||
nerveManager.init(true, false, true, true);
|
||||
}
|
||||
|
||||
private Map<Integer, Double> createTag(int tag) {//创建一个标注
|
||||
Map<Integer, Double> tagging = new HashMap<>();
|
||||
tagging.put(tag, 1.0);
|
||||
return tagging;
|
||||
}
|
||||
|
||||
public void look(List<double[]> data) throws Exception {
|
||||
int size = data.size();
|
||||
Map<Integer, Integer> map = new HashMap<>();
|
||||
for (int i = 0; i < size; i++) {
|
||||
rgbBack.clear();
|
||||
post(data.get(i), null, false);
|
||||
int type = rgbBack.getId();
|
||||
if (map.containsKey(type)) {
|
||||
map.put(type, map.get(type) + 1);
|
||||
} else {
|
||||
map.put(type, 1);
|
||||
}
|
||||
}
|
||||
double max = 0;
|
||||
int type = 0;
|
||||
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
|
||||
int nub = entry.getValue();
|
||||
if (nub > max) {
|
||||
max = nub;
|
||||
type = entry.getKey();
|
||||
}
|
||||
}
|
||||
double point = max / size;
|
||||
System.out.println("类型是:" + type + ",总票数:" + size + ",得票率:" + point);
|
||||
System.out.println("=================================完成");
|
||||
}
|
||||
|
||||
public void startStudy() throws Exception {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
for (int i = 0; i < times; i++) {
|
||||
for (Map.Entry<Integer, List<double[]>> entry : featureMap.entrySet()) {
|
||||
int type = entry.getKey();
|
||||
System.out.println("=============================" + type);
|
||||
Map<Integer, Double> tag = createTag(type);//标注
|
||||
double[] feature = entry.getValue().get(i);//数据
|
||||
post(feature, tag, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Map.Entry<Integer, List<double[]>> entry : featureMap.entrySet()) {
|
||||
int type = entry.getKey();
|
||||
System.out.println("=============================" + type);
|
||||
List<double[]> list = entry.getValue();
|
||||
look(list);
|
||||
}
|
||||
}
|
||||
|
||||
private void post(double[] data, Map<Integer, Double> tagging, boolean isStudy) throws Exception {
|
||||
List<SensoryNerve> sensoryNerveList = nerveManager.getSensoryNerves();
|
||||
int size = sensoryNerveList.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
sensoryNerveList.get(i).postMessage(1, data[i], isStudy, tagging, rgbBack);
|
||||
}
|
||||
}
|
||||
|
||||
public void setFeature(int type, List<double[]> feature) {
|
||||
if (type > 0) {
|
||||
if (featureMap.containsKey(type)) {
|
||||
featureMap.get(type).addAll(feature);
|
||||
} else {
|
||||
featureMap.put(type, feature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -1,206 +0,0 @@
|
||||
package org.wlld.imageRecognition.segmentation;
|
||||
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.MatrixTools.MatrixOperation;
|
||||
import org.wlld.imageRecognition.DistSort;
|
||||
import org.wlld.imageRecognition.TempleConfig;
|
||||
import org.wlld.imageRecognition.border.DistBody;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
* @Author LiDaPeng
|
||||
* @Description 修正映射
|
||||
*/
|
||||
public class RegionMapping {
|
||||
private Map<Integer, List<Matrix>> featureMap;
|
||||
private double studyPoint = 0.01;
|
||||
private Map<Integer, Matrix> powerMatrixMap = new HashMap<>();//映射层的权重矩阵
|
||||
private int times = 200;
|
||||
|
||||
public RegionMapping(TempleConfig templeConfig) {
|
||||
this.featureMap = templeConfig.getKnn().getFeatureMap();
|
||||
}
|
||||
|
||||
public void detection() throws Exception {//对模型进行检查,先对当前模型最相近的三个模型进行检出
|
||||
Map<Integer, List<DistBody>> map = new HashMap<>();
|
||||
DistSort distSort = new DistSort();
|
||||
for (Map.Entry<Integer, List<Matrix>> entry : featureMap.entrySet()) {
|
||||
int key = entry.getKey();
|
||||
List<Matrix> myFeature = entry.getValue();
|
||||
//记录 哪个类别 每个类别距离这个类别的最小距离
|
||||
List<DistBody> distBodies = new ArrayList<>();
|
||||
map.put(key, distBodies);
|
||||
for (Map.Entry<Integer, List<Matrix>> entrySon : featureMap.entrySet()) {
|
||||
int type = entrySon.getKey();
|
||||
if (key != type) {//当前其余类别
|
||||
DistBody distBody = new DistBody();
|
||||
List<Matrix> features = entrySon.getValue();//其余类别的所有特征
|
||||
double minError = -1;
|
||||
for (Matrix myMatrix : myFeature) {//待验证类别所有特征
|
||||
for (Matrix otherFeature : features) {
|
||||
double dist = MatrixOperation.getEDistByMatrix(myMatrix, otherFeature);
|
||||
if (minError < 0 || dist < minError) {
|
||||
minError = dist;
|
||||
}
|
||||
}
|
||||
}
|
||||
//当前类别遍历结束
|
||||
distBody.setId(type);
|
||||
distBody.setDist(minError);
|
||||
distBodies.add(distBody);
|
||||
}
|
||||
}
|
||||
}
|
||||
//进行排序
|
||||
for (Map.Entry<Integer, List<DistBody>> entry : map.entrySet()) {
|
||||
List<DistBody> list = entry.getValue();
|
||||
Collections.sort(list, distSort);
|
||||
}
|
||||
adjustWeight(map, 3);
|
||||
test(12, map, 3);
|
||||
// int testType = 5;
|
||||
// List<DistBody> testBody = map.get(testType);
|
||||
// Matrix matrix = powerMatrixMap.get(testType);
|
||||
// System.out.println(matrix.getString());
|
||||
}
|
||||
|
||||
private Matrix mapping(Matrix feature, Matrix mapping) throws Exception {
|
||||
int y = feature.getY();
|
||||
if (y == mapping.getY()) {
|
||||
Matrix matrix = new Matrix(1, y);
|
||||
for (int i = 0; i < y; i++) {
|
||||
double nub = feature.getNumber(0, i) * mapping.getNumber(0, i);
|
||||
matrix.setNub(0, i, nub);
|
||||
}
|
||||
return matrix;
|
||||
} else {
|
||||
throw new Exception("matrix is not equals");
|
||||
}
|
||||
}
|
||||
|
||||
private void test(int key, Map<Integer, List<DistBody>> map, int nub) throws Exception {
|
||||
Matrix matrixMapping = powerMatrixMap.get(key);//映射
|
||||
List<Matrix> rightFeature = featureMap.get(key);//正确的特征
|
||||
System.out.println("正确特征====================");
|
||||
for (Matrix matrix : rightFeature) {
|
||||
System.out.println(matrix.getString());
|
||||
}
|
||||
int y = rightFeature.get(0).getY();
|
||||
Matrix featureMatrix = getFeatureAvg(rightFeature, y);//均值
|
||||
List<DistBody> distBodies = map.get(key);
|
||||
List<Matrix> wrongFeature = new ArrayList<>();
|
||||
for (int i = 0; i < nub; i++) {
|
||||
DistBody distBody = distBodies.get(i);
|
||||
int id = distBody.getId();//相近的特征id
|
||||
wrongFeature.addAll(featureMap.get(id));
|
||||
}
|
||||
System.out.println("错误特征=============");
|
||||
for (Matrix matrix : wrongFeature) {
|
||||
System.out.println(matrix.getString());
|
||||
}
|
||||
//做验证
|
||||
double minOtherDist = -1;//其余特征最小距离
|
||||
for (Matrix wrongMatrix : wrongFeature) {
|
||||
Matrix wrongMapping = mapping(wrongMatrix, matrixMapping);
|
||||
double dist = MatrixOperation.getEDistByMatrix(featureMatrix, wrongMapping);
|
||||
System.out.println("异类距离:" + dist);
|
||||
if (minOtherDist < 0 || dist < minOtherDist) {
|
||||
minOtherDist = dist;
|
||||
}
|
||||
}
|
||||
System.out.println("异类最小距离:" + minOtherDist);
|
||||
for (Matrix rightMatrix : rightFeature) {
|
||||
Matrix rightMapping = mapping(rightMatrix, matrixMapping);
|
||||
double dist = MatrixOperation.getEDistByMatrix(featureMatrix, rightMapping);
|
||||
System.out.println("同类距离:" + dist);
|
||||
}
|
||||
}
|
||||
|
||||
private void adjustWeight(Map<Integer, List<DistBody>> map, int nub) throws Exception {//进行权重调整
|
||||
for (Map.Entry<Integer, List<DistBody>> entry : map.entrySet()) {
|
||||
int key = entry.getKey();//当前类别的id
|
||||
List<Matrix> rightFeature = featureMap.get(key);//正确的特征
|
||||
List<DistBody> distBodies = entry.getValue();
|
||||
List<Matrix> wrongFeature = new ArrayList<>();
|
||||
for (int i = 0; i < nub; i++) {
|
||||
DistBody distBody = distBodies.get(i);
|
||||
int id = distBody.getId();//相近的特征id
|
||||
wrongFeature.addAll(featureMap.get(id));
|
||||
}
|
||||
//对每个分类的特征图进行权重调整
|
||||
updatePower(wrongFeature, rightFeature, key);
|
||||
}
|
||||
}
|
||||
|
||||
//调整权重矩阵
|
||||
private void updatePower(List<Matrix> wrongFeature, List<Matrix> rightFeature, int key) throws Exception {
|
||||
int size = wrongFeature.size();
|
||||
int y = rightFeature.get(0).getY();
|
||||
//特征均值向量
|
||||
Matrix featureMatrix = getFeatureAvg(rightFeature, y);
|
||||
Matrix powerMatrix = new Matrix(1, y);
|
||||
powerMatrixMap.put(key, powerMatrix);
|
||||
Random random = new Random();
|
||||
for (int j = 0; j < times; j++) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
Matrix feature = rightFeature.get(random.nextInt(rightFeature.size()));
|
||||
Matrix noFeature = wrongFeature.get(i);
|
||||
powerDeflection(feature, featureMatrix, true, powerMatrix);
|
||||
powerDeflection(noFeature, featureMatrix, false, powerMatrix);
|
||||
}
|
||||
}
|
||||
end();
|
||||
}
|
||||
|
||||
private void end() throws Exception {
|
||||
for (Map.Entry<Integer, Matrix> entry : powerMatrixMap.entrySet()) {
|
||||
Matrix powerMatrix = entry.getValue();
|
||||
int y = powerMatrix.getY();
|
||||
double min = 0;
|
||||
for (int i = 0; i < y; i++) {
|
||||
double nub = powerMatrix.getNumber(0, i);
|
||||
if (nub < min) {
|
||||
min = nub;
|
||||
}
|
||||
}
|
||||
//获取最小值完毕
|
||||
for (int i = 0; i < y; i++) {
|
||||
double nub = powerMatrix.getNumber(0, i);
|
||||
powerMatrix.setNub(0, i, nub - min);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void powerDeflection(Matrix matrix1, Matrix matrix2, boolean polymerization,
|
||||
Matrix powerMatrix) throws Exception {
|
||||
int y = matrix1.getY();
|
||||
for (int i = 0; i < y; i++) {
|
||||
double sub = Math.abs(matrix1.getNumber(0, i) - matrix2.getNumber(0, i))
|
||||
* studyPoint;
|
||||
double power = powerMatrix.getNumber(0, i);//当前矩阵中的权值
|
||||
if (polymerization) {//同类聚合 聚合是减
|
||||
power = power - sub;
|
||||
} else {//异类离散
|
||||
power = power + sub;
|
||||
}
|
||||
powerMatrix.setNub(0, i, power);
|
||||
}
|
||||
}
|
||||
|
||||
private Matrix getFeatureAvg(List<Matrix> rightFeature, int size) throws Exception {//求特征均值
|
||||
Matrix feature = new Matrix(1, size);
|
||||
int nub = rightFeature.size();
|
||||
for (int i = 0; i < nub; i++) {
|
||||
Matrix matrix = rightFeature.get(i);
|
||||
for (int j = 0; j < size; j++) {
|
||||
double sigma = matrix.getNumber(0, j) + feature.getNumber(0, j);
|
||||
feature.setNub(0, j, sigma);
|
||||
}
|
||||
}
|
||||
MatrixOperation.mathDiv(feature, nub);
|
||||
return feature;
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package org.wlld.param;
|
||||
|
||||
import org.wlld.imageRecognition.modelEntity.DeepMappingBody;
|
||||
import org.wlld.imageRecognition.segmentation.KNerveManger;
|
||||
import org.wlld.imageRecognition.segmentation.RgbRegression;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -24,7 +23,6 @@ public class Food {
|
||||
private int regionSize = 5;//纹理区域大小
|
||||
private int step = 1;//特征取样步长
|
||||
private int speciesNub = 24;//种类数
|
||||
private KNerveManger kNerveManger;
|
||||
private DeepMappingBody deepMappingBody;//特征映射
|
||||
|
||||
public DeepMappingBody getDeepMappingBody() {
|
||||
@ -35,14 +33,6 @@ public class Food {
|
||||
this.deepMappingBody = deepMappingBody;
|
||||
}
|
||||
|
||||
public KNerveManger getkNerveManger() {
|
||||
return kNerveManger;
|
||||
}
|
||||
|
||||
public void setkNerveManger(KNerveManger kNerveManger) {
|
||||
this.kNerveManger = kNerveManger;
|
||||
}
|
||||
|
||||
public int getSpeciesNub() {
|
||||
return speciesNub;
|
||||
}
|
||||
|
@ -71,8 +71,6 @@ public class FoodTest {
|
||||
//菜品识别实体类
|
||||
food.setShrink(0);//缩紧像素
|
||||
food.setRegionSize(2);
|
||||
KNerveManger kNerveManger = new KNerveManger(12, 24, 6000);
|
||||
food.setkNerveManger(kNerveManger);
|
||||
food.setRowMark(0.05);//0.12
|
||||
food.setColumnMark(0.05);//0.25
|
||||
food.setRegressionNub(50000);
|
||||
@ -128,12 +126,9 @@ public class FoodTest {
|
||||
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix(c);
|
||||
study(threeChannelMatrix1, templeConfig, specificationsList, cutFood, 2);
|
||||
study(threeChannelMatrix2, templeConfig, specificationsList, cutFood, 3);
|
||||
///
|
||||
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix(g);
|
||||
long time1 = System.currentTimeMillis();
|
||||
//
|
||||
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix(e);
|
||||
look(threeChannelMatrix3, templeConfig, specificationsList, cutFood);
|
||||
long end = System.currentTimeMillis() - time1;
|
||||
System.out.println("time:" + end);
|
||||
}
|
||||
|
||||
private static void look(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig, List<Specifications> specifications,
|
||||
|
Loading…
Reference in New Issue
Block a user