增加Knn分类器

This commit is contained in:
thenk008 2020-05-20 16:55:22 +08:00
parent ce889e439b
commit 7ab9886674
9 changed files with 267 additions and 93 deletions

View File

@ -4,4 +4,5 @@ public class Classifier {//分类器
public static final int LVQ = 1;//LVQ分类 你的训练模版量非常少 比如 一种只有几十一百张照片/分类少
public static final int DNN = 2; //使用DNN分类 训练量足够大一个种类1500+训练图片
public static final int VAvg = 3;//使用特征向量均值分类 一种只有几十一百张照片
public static final int KNN = 4;//KNN分类器
}

View File

@ -106,7 +106,7 @@ public class Convolution extends Frequency {
return features;
}
public List<Double> getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub) throws Exception {
public List<double[]> getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub) throws Exception {
Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB();
@ -126,9 +126,9 @@ public class Convolution extends Frequency {
meanClustering.start();
List<RGBNorm> rgbNorms = meanClustering.getMatrices();
Collections.sort(rgbNorms, rgbSort);
List<Double> feature = new ArrayList<>();
List<double[]> feature = new ArrayList<>();
for (int i = 0; i < sqNub; i++) {
feature.add(rgbNorms.get(i).getNorm());
feature.add(rgbNorms.get(i).getRgb());
}
//System.out.println("feature==" + feature);
return feature;

View File

@ -18,10 +18,7 @@ import org.wlld.nerveEntity.SensoryNerve;
import org.wlld.tools.ArithUtil;
import org.wlld.tools.IdCreator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
public class Operation {//进行计算
private TempleConfig templeConfig;//配置初始化参数模板
@ -70,31 +67,43 @@ public class Operation {//进行计算
int xSize = maxX - minX;
int ySize = maxY - minY;
ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize);
List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
List<double[]> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
templeConfig.getFeatureNub());
if (templeConfig.isShowLog()) {
System.out.println(feature);
for (double[] f : feature) {
System.out.println(Arrays.toString(f));
}
}
System.out.println("=====================================");
int classifier = templeConfig.getClassifier();
switch (classifier) {
case Classifier.DNN:
Map<Integer, Double> map = new HashMap<>();
map.put(tag, 1.0);
if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) {
intoDnnNetwork(1, feature, templeConfig.getSensoryNerves(), true, map, null);
} else {
throw new Exception("nerves number is not equal featureNub");
}
break;
case Classifier.LVQ:
Matrix vector = MatrixOperation.listToRowVector(feature);
lvqStudy(tag, vector);
break;
case Classifier.VAvg:
Matrix vec = MatrixOperation.listToRowVector(feature);
avgStudy(tag, vec);
break;
}
// switch (classifier) {
// case Classifier.DNN:
// Map<Integer, Double> map = new HashMap<>();
// map.put(tag, 1.0);
// if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) {
// intoDnnNetwork(1, feature, templeConfig.getSensoryNerves(), true, map, null);
// } else {
// throw new Exception("nerves number is not equal featureNub");
// }
// break;
// case Classifier.LVQ:
// Matrix vector = MatrixOperation.listToRowVector(feature);
// lvqStudy(tag, vector);
// break;
// case Classifier.VAvg:
// Matrix vec = MatrixOperation.listToRowVector(feature);
// avgStudy(tag, vec);
// break;
// case Classifier.KNN:
// Matrix veck = MatrixOperation.listToRowVector(feature);
// knnStudy(tag, veck);
// break;
// }
}
private void knnStudy(int tagging, Matrix vector) throws Exception {
Knn knn = templeConfig.getKnn();
knn.insertMatrix(vector, tagging);
}
private void avgStudy(int tagging, Matrix myMatrix) throws Exception {//特征矩阵均值学习
@ -122,32 +131,38 @@ public class Operation {//进行计算
int xSize = maxX - minX;
int ySize = maxY - minY;
ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize);
List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
List<double[]> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
templeConfig.getFeatureNub());
if (templeConfig.isShowLog()) {
System.out.println(feature);
}
int classifier = templeConfig.getClassifier();
int id = 0;
switch (classifier) {
case Classifier.LVQ:
Matrix myMatrix = MatrixOperation.listToRowVector(feature);
id = getIdByLVQ(myMatrix);
break;
case Classifier.DNN:
if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) {
intoDnnNetwork(IdCreator.get().nextId(), feature, templeConfig.getSensoryNerves(), false, null, maxPoint);
id = maxPoint.getId();
} else {
throw new Exception("nerves number is not equal featureNub");
}
break;
case Classifier.VAvg:
Matrix myMatrix1 = MatrixOperation.listToRowVector(feature);
id = getIdByVag(myMatrix1);
break;
}
regionBody.setType(id);
// int classifier = templeConfig.getClassifier();
// int id = 0;
// switch (classifier) {
// case Classifier.LVQ:
// Matrix myMatrix = MatrixOperation.listToRowVector(feature);
// id = getIdByLVQ(myMatrix);
// break;
// case Classifier.DNN:
// if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) {
// intoDnnNetwork(IdCreator.get().nextId(), feature, templeConfig.getSensoryNerves(), false, null, maxPoint);
// id = maxPoint.getId();
// } else {
// throw new Exception("nerves number is not equal featureNub");
// }
// break;
// case Classifier.VAvg:
// Matrix myMatrix1 = MatrixOperation.listToRowVector(feature);
// id = getIdByVag(myMatrix1);
// break;
// case Classifier.KNN:
// Matrix myMatrix2 = MatrixOperation.listToRowVector(feature);
// Knn knn = templeConfig.getKnn();
// id = knn.getType(myMatrix2);
// break;
// }
// regionBody.setType(id);
//System.out.println("类别" + id);
}
return regionList;
}
@ -159,12 +174,13 @@ public class Operation {//进行计算
for (Map.Entry<Integer, Matrix> entry : matrixK.entrySet()) {
Matrix matrix = entry.getValue();
double dist = MatrixOperation.getEDist(matrix, myVector);
//System.out.println("距离===" + dist + ",类别==" + entry.getKey());
//System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString());
if (minDist == 0 || dist < minDist) {
minDist = dist;
id = entry.getKey();
}
}
//System.out.println("=======================");
return id;
}

View File

@ -45,7 +45,7 @@ public class TempleConfig {
private double th = -1;//标准阈值
private boolean boxReady = false;//边框已经学习完毕
private double iouTh = 0.5;//IOU阈值
private int lvqNub = 10;//lvq循环次数默认30
private int lvqNub = 100;//lvq循环次数默认30
private VectorK vectorK;//特征向量均值类
private boolean isThreeChannel = false;//是否启用三通道
private int classifier = Classifier.VAvg;//默认分类类别使用的是向量均值分类
@ -53,7 +53,7 @@ public class TempleConfig {
private int sensoryNerveNub = 9;//输入神经元个数
private boolean isShowLog = false;
private ActiveFunction activeFunction = new Tanh();
private double studyPoint = 0;
private double studyPoint = 0.1;
private double matrixWidth = 5;//期望矩阵间隔
private int rzType = RZ.NOT_RZ;//正则化类型默认不进行正则化
private double lParam = 0;//正则参数
@ -64,6 +64,20 @@ public class TempleConfig {
private double hTh = 0.88;//灰度阈值
private double maxRain = 340;//不降雨RGB阈值
private int featureNub = 4;//聚类特征数量
private Knn knn;//KNN分类器
private int knnNub = 5;//KNN投票人数
public Knn getKnn() {
return knn;
}
public int getKnnNub() {
return knnNub;
}
public void setKnnNub(int knnNub) {
this.knnNub = knnNub;
}
public int getFeatureNub() {
return featureNub;
@ -304,7 +318,10 @@ public class TempleConfig {
lvq = new LVQ(classificationNub, lvqNub, studyPoint);
break;
case Classifier.VAvg:
vectorK = new VectorK(sensoryNerveNub);
vectorK = new VectorK(featureNub);
break;
case Classifier.KNN:
knn = new Knn(knnNub);
break;
}
break;

View File

@ -0,0 +1,101 @@
package org.wlld.imageRecognition.border;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Knn {//KNN分类器
private Map<Integer, List<Matrix>> featureMap = new HashMap<>();
private int length;//向量长度(需要返回)
private int nub;//选择几个人投票
public Knn(int nub) {
this.nub = nub;
}
public void insertMatrix(Matrix vector, int tag) throws Exception {
if (vector.isVector() && vector.isRowVector()) {
if (featureMap.size() == 0) {
List<Matrix> list = new ArrayList<>();
list.add(vector);
featureMap.put(tag, list);
length = vector.getY();
} else {
if (length == vector.getY()) {
if (featureMap.containsKey(tag)) {
featureMap.get(tag).add(vector);
} else {
List<Matrix> list = new ArrayList<>();
list.add(vector);
featureMap.put(tag, list);
}
} else {
throw new Exception("vector length is different");
}
}
} else {
throw new Exception("this matrix is not vector or rowVector");
}
}
private void compare(double[] values, int[] types, double value, int type) {
for (int i = 0; i < values.length; i++) {
double val = values[i];
if (val < 0) {
values[i] = value;
types[i] = type;
break;
} else {
if (value < val) {
for (int j = values.length - 2; j >= i; j--) {
values[j + 1] = values[j];
types[j + 1] = types[j];
}
values[i] = value;
types[i] = type;
break;
}
}
}
}
public int getType(Matrix vector) throws Exception {//识别分类
int ty = 0;
double[] dists = new double[nub];
int[] types = new int[nub];
for (int i = 0; i < nub; i++) {
dists[i] = -1;
}
for (Map.Entry<Integer, List<Matrix>> entry : featureMap.entrySet()) {
int type = entry.getKey();
List<Matrix> matrices = entry.getValue();
for (Matrix matrix : matrices) {
double dist = MatrixOperation.getEDist(matrix, vector);
compare(dists, types, dist, type);
}
}
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < nub; i++) {
int type = types[i];
if (map.containsKey(type)) {
map.put(type, map.get(type) + 1);
} else {
map.put(type, 1);
}
}
int max = 0;
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
int value = entry.getValue();
int type = entry.getKey();
if (value > max) {
ty = type;
}
}
return ty;
}
}

View File

@ -87,16 +87,35 @@ public class LVQ {
long type = matrixBody.getId();//类别
double distEnd = 0;
int id = 0;
double dis0 = 0;
double dis1 = 0;
double dis2 = 0;
double dis3 = 0;
for (int i = 0; i < typeNub; i++) {
MatrixBody modelBody = model[i];
Matrix modelMatrix = modelBody.getMatrix();
//修正矩阵与原矩阵的范数差
double dist = vectorEqual(modelMatrix, matrix);
switch (i) {
case 0:
dis0 = dist;
break;
case 1:
dis1 = dist;
break;
case 2:
dis2 = dist;
break;
case 3:
dis3 = dist;
break;
}
if (distEnd == 0 || dist < distEnd) {
id = modelBody.getId();
distEnd = dist;
}
}
System.out.println("type==" + type + ",dist0==" + dis0 + ",dist1==" + dis1 + ",dist2==" + dis2 + ",dist3==" + dis3);
MatrixBody modelBody = model[id];
Matrix modelMatrix = modelBody.getMatrix();
boolean isRight = id == type;
@ -137,6 +156,7 @@ public class LVQ {
}
//初始化完成
for (int i = 0; i < lvqNub; i++) {
System.out.println("================================");
study();
}
isReady = true;

View File

@ -20,7 +20,7 @@ public class Watershed {
private double th = Kernel.th;//灰度阈值
private int regionNub = Kernel.Region_Nub;//一张图分多少份
private Map<Integer, RegionBody> regionBodyMap = new HashMap<>();
private double rainTh = 2;
private double rainTh = 0;
private int xMax;
private int yMax;
private double maxRain;

View File

@ -23,12 +23,12 @@ import java.util.Map;
*/
public class CoverTest {
public static void main(String[] args) throws Exception {
double a = 220;
double b = 220;
double c = 220;
double d = Math.sqrt(Math.pow(a, 2) + Math.pow(b, 2) + Math.pow(c, 2));
System.out.println(d);
//cover();
// double a = 220;
// double b = 220;
// double c = 220;
// double d = Math.sqrt(Math.pow(a, 2) + Math.pow(b, 2) + Math.pow(c, 2));
// System.out.println(d);
cover();
}
public static void insertModel(String model) throws Exception {//注入模型
@ -64,7 +64,7 @@ public class CoverTest {
operation = getModel();
}
for (int i = 1; i < 100; i++) {
String na = "D:\\pic\\test/" + name + i + ".jpg";
String na = "D:\\share\\cai/" + name + i + ".jpg";
//System.out.println("name======================" + na);
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix(na);
Map<Integer, Double> map1 = operation.coverPoint(threeChannelMatrix, poolSize, sqlNub, regionSize);
@ -108,19 +108,28 @@ public class CoverTest {
templeConfig.setSensoryNerveNub(3);//多出来的
templeConfig.setRzType(RZ.L1);//不动//3 18
templeConfig.setlParam(0.015);//不动
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 2);
templeConfig.setClassifier(Classifier.DNN);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 4);
Operation operation = new Operation(templeConfig);
for (int i = 1; i < 100; i++) {
Map<Integer, ThreeChannelMatrix> matrixMap = new HashMap<>();
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic\\test/b" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic\\test/d" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\share\\cai/a" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\share\\cai/b" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\share\\cai/c" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\share\\cai/d" + i + ".jpg");
matrixMap.put(1, threeChannelMatrix1);//桔梗覆盖
matrixMap.put(2, threeChannelMatrix2);//土地
matrixMap.put(3, threeChannelMatrix3);//桔梗覆盖
matrixMap.put(4, threeChannelMatrix4);//土地
operation.coverStudy(matrixMap, 2, 3, 18, 2);
}
ModelParameter modelParameter = templeConfig.getModel();
String model = JSON.toJSONString(modelParameter);
System.out.println(model);
test(operation, 2, 3, 18, "d", 2);
test(operation, 2, 3, 18, "a", 1);
test(operation, 2, 3, 18, "b", 2);
test(operation, 2, 3, 18, "c", 3);
test(operation, 2, 3, 18, "d", 4);
}
}

View File

@ -8,8 +8,6 @@ import org.wlld.config.StudyPattern;
import org.wlld.imageRecognition.*;
import org.wlld.imageRecognition.segmentation.RegionBody;
import org.wlld.imageRecognition.segmentation.Specifications;
import org.wlld.imageRecognition.segmentation.Watershed;
import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.tools.ArithUtil;
@ -19,42 +17,51 @@ import java.util.List;
public class FoodTest {
public static void main(String[] args) throws Exception {
//food();
//rain();
test2();
//test2();
test();
}
public static void test2() throws Exception {
public static void test2(TempleConfig templeConfig) throws Exception {
//test();
TempleConfig templeConfig = new TempleConfig();
Picture picture = new Picture();
templeConfig.setSensoryNerveNub(4);
templeConfig.setStudyPoint(0.01);
templeConfig.sethTh(0.86);
templeConfig.setRegionNub(200);
templeConfig.setMaxRain(360);
templeConfig.setSoftMax(true);
List<Specifications> specificationsList = new ArrayList<>();
Specifications specifications = new Specifications();
specifications.setWidth(400);
specifications.setHeight(400);
specificationsList.add(specifications);
templeConfig.setClassifier(Classifier.LVQ);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
Operation operation = new Operation(templeConfig);
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("D:\\cai\\e/e1.jpg");
operation.colorLook(threeChannelMatrix, specificationsList);
for (int i = 1; i <= 10; i++) {
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic/a" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic/b" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\pic/c" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\pic/d" + i + ".jpg");
RegionBody regionBody1 = operation.colorLook(threeChannelMatrix1, specificationsList).get(0);
RegionBody regionBody2 = operation.colorLook(threeChannelMatrix2, specificationsList).get(0);
RegionBody regionBody3 = operation.colorLook(threeChannelMatrix3, specificationsList).get(0);
RegionBody regionBody4 = operation.colorLook(threeChannelMatrix4, specificationsList).get(0);
System.out.println("type1==" + regionBody1.getType());
System.out.println("type2==" + regionBody2.getType());
System.out.println("type3==" + regionBody3.getType());
System.out.println("type4==" + regionBody4.getType());
System.out.println("==========================================" + i);
}
}
public static void test() throws Exception {
TempleConfig templeConfig = new TempleConfig();
Picture picture = new Picture();
templeConfig.setSensoryNerveNub(4);
templeConfig.setStudyPoint(0.01);
templeConfig.isShowLog(true);
templeConfig.setMaxRain(320);
templeConfig.setSensoryNerveNub(3);
templeConfig.setSoftMax(true);
templeConfig.setClassifier(Classifier.LVQ);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
templeConfig.setFeatureNub(3);
templeConfig.sethTh(0.88);
templeConfig.setPoolSize(2);
templeConfig.setRzType(RZ.L1);
templeConfig.setlParam(0.015);
templeConfig.setClassifier(Classifier.VAvg);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 4);
Operation operation = new Operation(templeConfig);
List<Specifications> specificationsList = new ArrayList<>();
Specifications specifications = new Specifications();
@ -63,15 +70,19 @@ public class FoodTest {
specificationsList.add(specifications);
for (int j = 0; j < 1; j++) {
for (int i = 1; i <= 10; i++) {
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\cai/a/a" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\cai/b/b" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\cai/c/c" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic/a" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic/b" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\pic/c" + i + ".jpg");
ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\pic/d" + i + ".jpg");
operation.colorStudy(threeChannelMatrix1, 1, specificationsList);
operation.colorStudy(threeChannelMatrix2, 2, specificationsList);
operation.colorStudy(threeChannelMatrix3, 3, specificationsList);
System.out.println("=======================================");
operation.colorStudy(threeChannelMatrix4, 4, specificationsList);
System.out.println("=======================================" + i);
}
}
//templeConfig.finishStudy();
//test2(templeConfig);
}
public static void study() throws Exception {
@ -89,8 +100,7 @@ public class FoodTest {
Picture picture = new Picture();
Convolution convolution = new Convolution();
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("");
List<Double> feature = convolution.getCenterColor(threeChannelMatrix, 2, 4);
// List<Double> feature = convolution.getCenterColor(threeChannelMatrix, 2, 4);
}
public static void food() throws Exception {