diff --git a/src/main/java/org/wlld/config/Classifier.java b/src/main/java/org/wlld/config/Classifier.java index 9b218c4..0d9c6ae 100644 --- a/src/main/java/org/wlld/config/Classifier.java +++ b/src/main/java/org/wlld/config/Classifier.java @@ -4,4 +4,5 @@ public class Classifier {//分类器 public static final int LVQ = 1;//LVQ分类 你的训练模版量非常少 比如 一种只有几十一百张照片/分类少 public static final int DNN = 2; //使用DNN分类 训练量足够大,一个种类1500+训练图片 public static final int VAvg = 3;//使用特征向量均值分类 一种只有几十一百张照片 + public static final int KNN = 4;//KNN分类器 } diff --git a/src/main/java/org/wlld/imageRecognition/Convolution.java b/src/main/java/org/wlld/imageRecognition/Convolution.java index 441516d..f378cb2 100644 --- a/src/main/java/org/wlld/imageRecognition/Convolution.java +++ b/src/main/java/org/wlld/imageRecognition/Convolution.java @@ -106,7 +106,7 @@ public class Convolution extends Frequency { return features; } - public List getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub) throws Exception { + public List getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub) throws Exception { Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixB = threeChannelMatrix.getMatrixB(); @@ -126,9 +126,9 @@ public class Convolution extends Frequency { meanClustering.start(); List rgbNorms = meanClustering.getMatrices(); Collections.sort(rgbNorms, rgbSort); - List feature = new ArrayList<>(); + List feature = new ArrayList<>(); for (int i = 0; i < sqNub; i++) { - feature.add(rgbNorms.get(i).getNorm()); + feature.add(rgbNorms.get(i).getRgb()); } //System.out.println("feature==" + feature); return feature; diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index 02e933c..3a1ef9f 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -18,10 +18,7 @@ import org.wlld.nerveEntity.SensoryNerve; import org.wlld.tools.ArithUtil; import org.wlld.tools.IdCreator; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; public class Operation {//进行计算 private TempleConfig templeConfig;//配置初始化参数模板 @@ -70,31 +67,43 @@ public class Operation {//进行计算 int xSize = maxX - minX; int ySize = maxY - minY; ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize); - List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), + List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), templeConfig.getFeatureNub()); if (templeConfig.isShowLog()) { - System.out.println(feature); + for (double[] f : feature) { + System.out.println(Arrays.toString(f)); + } } + System.out.println("====================================="); int classifier = templeConfig.getClassifier(); - switch (classifier) { - case Classifier.DNN: - Map map = new HashMap<>(); - map.put(tag, 1.0); - if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) { - intoDnnNetwork(1, feature, templeConfig.getSensoryNerves(), true, map, null); - } else { - throw new Exception("nerves number is not equal featureNub"); - } - break; - case Classifier.LVQ: - Matrix vector = MatrixOperation.listToRowVector(feature); - lvqStudy(tag, vector); - break; - case Classifier.VAvg: - Matrix vec = MatrixOperation.listToRowVector(feature); - avgStudy(tag, vec); - break; - } +// switch (classifier) { +// case Classifier.DNN: +// Map map = new HashMap<>(); +// map.put(tag, 1.0); +// if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) { +// intoDnnNetwork(1, feature, templeConfig.getSensoryNerves(), true, map, null); +// } else { +// throw new Exception("nerves number is not equal featureNub"); +// } +// break; +// case Classifier.LVQ: +// Matrix vector = MatrixOperation.listToRowVector(feature); +// lvqStudy(tag, vector); +// break; +// case Classifier.VAvg: +// Matrix vec = MatrixOperation.listToRowVector(feature); +// avgStudy(tag, vec); +// break; +// case Classifier.KNN: +// Matrix veck = MatrixOperation.listToRowVector(feature); +// knnStudy(tag, veck); +// break; +// } + } + + private void knnStudy(int tagging, Matrix vector) throws Exception { + Knn knn = templeConfig.getKnn(); + knn.insertMatrix(vector, tagging); } private void avgStudy(int tagging, Matrix myMatrix) throws Exception {//特征矩阵均值学习 @@ -122,32 +131,38 @@ public class Operation {//进行计算 int xSize = maxX - minX; int ySize = maxY - minY; ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize); - List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), + List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), templeConfig.getFeatureNub()); if (templeConfig.isShowLog()) { System.out.println(feature); } - int classifier = templeConfig.getClassifier(); - int id = 0; - switch (classifier) { - case Classifier.LVQ: - Matrix myMatrix = MatrixOperation.listToRowVector(feature); - id = getIdByLVQ(myMatrix); - break; - case Classifier.DNN: - if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) { - intoDnnNetwork(IdCreator.get().nextId(), feature, templeConfig.getSensoryNerves(), false, null, maxPoint); - id = maxPoint.getId(); - } else { - throw new Exception("nerves number is not equal featureNub"); - } - break; - case Classifier.VAvg: - Matrix myMatrix1 = MatrixOperation.listToRowVector(feature); - id = getIdByVag(myMatrix1); - break; - } - regionBody.setType(id); +// int classifier = templeConfig.getClassifier(); +// int id = 0; +// switch (classifier) { +// case Classifier.LVQ: +// Matrix myMatrix = MatrixOperation.listToRowVector(feature); +// id = getIdByLVQ(myMatrix); +// break; +// case Classifier.DNN: +// if (templeConfig.getSensoryNerves().size() == templeConfig.getFeatureNub()) { +// intoDnnNetwork(IdCreator.get().nextId(), feature, templeConfig.getSensoryNerves(), false, null, maxPoint); +// id = maxPoint.getId(); +// } else { +// throw new Exception("nerves number is not equal featureNub"); +// } +// break; +// case Classifier.VAvg: +// Matrix myMatrix1 = MatrixOperation.listToRowVector(feature); +// id = getIdByVag(myMatrix1); +// break; +// case Classifier.KNN: +// Matrix myMatrix2 = MatrixOperation.listToRowVector(feature); +// Knn knn = templeConfig.getKnn(); +// id = knn.getType(myMatrix2); +// break; +// } +// regionBody.setType(id); + //System.out.println("类别" + id); } return regionList; } @@ -159,12 +174,13 @@ public class Operation {//进行计算 for (Map.Entry entry : matrixK.entrySet()) { Matrix matrix = entry.getValue(); double dist = MatrixOperation.getEDist(matrix, myVector); - //System.out.println("距离===" + dist + ",类别==" + entry.getKey()); + //System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString()); if (minDist == 0 || dist < minDist) { minDist = dist; id = entry.getKey(); } } + //System.out.println("======================="); return id; } diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 84d3416..7b2dbcd 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -45,7 +45,7 @@ public class TempleConfig { private double th = -1;//标准阈值 private boolean boxReady = false;//边框已经学习完毕 private double iouTh = 0.5;//IOU阈值 - private int lvqNub = 10;//lvq循环次数,默认30 + private int lvqNub = 100;//lvq循环次数,默认30 private VectorK vectorK;//特征向量均值类 private boolean isThreeChannel = false;//是否启用三通道 private int classifier = Classifier.VAvg;//默认分类类别使用的是向量均值分类 @@ -53,7 +53,7 @@ public class TempleConfig { private int sensoryNerveNub = 9;//输入神经元个数 private boolean isShowLog = false; private ActiveFunction activeFunction = new Tanh(); - private double studyPoint = 0; + private double studyPoint = 0.1; private double matrixWidth = 5;//期望矩阵间隔 private int rzType = RZ.NOT_RZ;//正则化类型,默认不进行正则化 private double lParam = 0;//正则参数 @@ -64,6 +64,20 @@ public class TempleConfig { private double hTh = 0.88;//灰度阈值 private double maxRain = 340;//不降雨RGB阈值 private int featureNub = 4;//聚类特征数量 + private Knn knn;//KNN分类器 + private int knnNub = 5;//KNN投票人数 + + public Knn getKnn() { + return knn; + } + + public int getKnnNub() { + return knnNub; + } + + public void setKnnNub(int knnNub) { + this.knnNub = knnNub; + } public int getFeatureNub() { return featureNub; @@ -304,7 +318,10 @@ public class TempleConfig { lvq = new LVQ(classificationNub, lvqNub, studyPoint); break; case Classifier.VAvg: - vectorK = new VectorK(sensoryNerveNub); + vectorK = new VectorK(featureNub); + break; + case Classifier.KNN: + knn = new Knn(knnNub); break; } break; diff --git a/src/main/java/org/wlld/imageRecognition/border/Knn.java b/src/main/java/org/wlld/imageRecognition/border/Knn.java new file mode 100644 index 0000000..f9c2d49 --- /dev/null +++ b/src/main/java/org/wlld/imageRecognition/border/Knn.java @@ -0,0 +1,101 @@ +package org.wlld.imageRecognition.border; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Knn {//KNN分类器 + private Map> featureMap = new HashMap<>(); + private int length;//向量长度(需要返回) + private int nub;//选择几个人投票 + + public Knn(int nub) { + this.nub = nub; + } + + public void insertMatrix(Matrix vector, int tag) throws Exception { + if (vector.isVector() && vector.isRowVector()) { + if (featureMap.size() == 0) { + List list = new ArrayList<>(); + list.add(vector); + featureMap.put(tag, list); + length = vector.getY(); + } else { + if (length == vector.getY()) { + if (featureMap.containsKey(tag)) { + featureMap.get(tag).add(vector); + } else { + List list = new ArrayList<>(); + list.add(vector); + featureMap.put(tag, list); + } + } else { + throw new Exception("vector length is different"); + } + } + } else { + throw new Exception("this matrix is not vector or rowVector"); + } + } + + private void compare(double[] values, int[] types, double value, int type) { + for (int i = 0; i < values.length; i++) { + double val = values[i]; + if (val < 0) { + values[i] = value; + types[i] = type; + break; + } else { + if (value < val) { + for (int j = values.length - 2; j >= i; j--) { + values[j + 1] = values[j]; + types[j + 1] = types[j]; + } + values[i] = value; + types[i] = type; + break; + } + } + } + } + + public int getType(Matrix vector) throws Exception {//识别分类 + int ty = 0; + double[] dists = new double[nub]; + int[] types = new int[nub]; + for (int i = 0; i < nub; i++) { + dists[i] = -1; + } + for (Map.Entry> entry : featureMap.entrySet()) { + int type = entry.getKey(); + List matrices = entry.getValue(); + for (Matrix matrix : matrices) { + double dist = MatrixOperation.getEDist(matrix, vector); + compare(dists, types, dist, type); + } + } + Map map = new HashMap<>(); + for (int i = 0; i < nub; i++) { + int type = types[i]; + if (map.containsKey(type)) { + map.put(type, map.get(type) + 1); + } else { + map.put(type, 1); + } + } + int max = 0; + for (Map.Entry entry : map.entrySet()) { + int value = entry.getValue(); + int type = entry.getKey(); + if (value > max) { + ty = type; + } + } + return ty; + } + +} diff --git a/src/main/java/org/wlld/imageRecognition/border/LVQ.java b/src/main/java/org/wlld/imageRecognition/border/LVQ.java index 988a72f..5bf6106 100644 --- a/src/main/java/org/wlld/imageRecognition/border/LVQ.java +++ b/src/main/java/org/wlld/imageRecognition/border/LVQ.java @@ -87,16 +87,35 @@ public class LVQ { long type = matrixBody.getId();//类别 double distEnd = 0; int id = 0; + double dis0 = 0; + double dis1 = 0; + double dis2 = 0; + double dis3 = 0; for (int i = 0; i < typeNub; i++) { MatrixBody modelBody = model[i]; Matrix modelMatrix = modelBody.getMatrix(); //修正矩阵与原矩阵的范数差 double dist = vectorEqual(modelMatrix, matrix); + switch (i) { + case 0: + dis0 = dist; + break; + case 1: + dis1 = dist; + break; + case 2: + dis2 = dist; + break; + case 3: + dis3 = dist; + break; + } if (distEnd == 0 || dist < distEnd) { id = modelBody.getId(); distEnd = dist; } } + System.out.println("type==" + type + ",dist0==" + dis0 + ",dist1==" + dis1 + ",dist2==" + dis2 + ",dist3==" + dis3); MatrixBody modelBody = model[id]; Matrix modelMatrix = modelBody.getMatrix(); boolean isRight = id == type; @@ -137,6 +156,7 @@ public class LVQ { } //初始化完成 for (int i = 0; i < lvqNub; i++) { + System.out.println("================================"); study(); } isReady = true; diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java index 507f250..e473024 100644 --- a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java +++ b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java @@ -20,7 +20,7 @@ public class Watershed { private double th = Kernel.th;//灰度阈值 private int regionNub = Kernel.Region_Nub;//一张图分多少份 private Map regionBodyMap = new HashMap<>(); - private double rainTh = 2; + private double rainTh = 0; private int xMax; private int yMax; private double maxRain; diff --git a/src/test/java/coverTest/CoverTest.java b/src/test/java/coverTest/CoverTest.java index 1a6461a..59be1be 100644 --- a/src/test/java/coverTest/CoverTest.java +++ b/src/test/java/coverTest/CoverTest.java @@ -23,12 +23,12 @@ import java.util.Map; */ public class CoverTest { public static void main(String[] args) throws Exception { - double a = 220; - double b = 220; - double c = 220; - double d = Math.sqrt(Math.pow(a, 2) + Math.pow(b, 2) + Math.pow(c, 2)); - System.out.println(d); - //cover(); +// double a = 220; +// double b = 220; +// double c = 220; +// double d = Math.sqrt(Math.pow(a, 2) + Math.pow(b, 2) + Math.pow(c, 2)); +// System.out.println(d); + cover(); } public static void insertModel(String model) throws Exception {//注入模型 @@ -64,7 +64,7 @@ public class CoverTest { operation = getModel(); } for (int i = 1; i < 100; i++) { - String na = "D:\\pic\\test/" + name + i + ".jpg"; + String na = "D:\\share\\cai/" + name + i + ".jpg"; //System.out.println("name======================" + na); ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix(na); Map map1 = operation.coverPoint(threeChannelMatrix, poolSize, sqlNub, regionSize); @@ -108,19 +108,28 @@ public class CoverTest { templeConfig.setSensoryNerveNub(3);//多出来的 templeConfig.setRzType(RZ.L1);//不动//3 18 templeConfig.setlParam(0.015);//不动 - templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 2); + templeConfig.setClassifier(Classifier.DNN); + templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 4); Operation operation = new Operation(templeConfig); for (int i = 1; i < 100; i++) { Map matrixMap = new HashMap<>(); - ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic\\test/b" + i + ".jpg"); - ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic\\test/d" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\share\\cai/a" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\share\\cai/b" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\share\\cai/c" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\share\\cai/d" + i + ".jpg"); matrixMap.put(1, threeChannelMatrix1);//桔梗覆盖 matrixMap.put(2, threeChannelMatrix2);//土地 + matrixMap.put(3, threeChannelMatrix3);//桔梗覆盖 + matrixMap.put(4, threeChannelMatrix4);//土地 operation.coverStudy(matrixMap, 2, 3, 18, 2); } ModelParameter modelParameter = templeConfig.getModel(); String model = JSON.toJSONString(modelParameter); System.out.println(model); - test(operation, 2, 3, 18, "d", 2); + test(operation, 2, 3, 18, "a", 1); + test(operation, 2, 3, 18, "b", 2); + test(operation, 2, 3, 18, "c", 3); + test(operation, 2, 3, 18, "d", 4); + } } diff --git a/src/test/java/coverTest/FoodTest.java b/src/test/java/coverTest/FoodTest.java index 4ec37f8..abe74d3 100644 --- a/src/test/java/coverTest/FoodTest.java +++ b/src/test/java/coverTest/FoodTest.java @@ -8,8 +8,6 @@ import org.wlld.config.StudyPattern; import org.wlld.imageRecognition.*; import org.wlld.imageRecognition.segmentation.RegionBody; import org.wlld.imageRecognition.segmentation.Specifications; -import org.wlld.imageRecognition.segmentation.Watershed; -import org.wlld.nerveCenter.NerveManager; import org.wlld.nerveEntity.ModelParameter; import org.wlld.tools.ArithUtil; @@ -19,42 +17,51 @@ import java.util.List; public class FoodTest { public static void main(String[] args) throws Exception { - //food(); - //rain(); - test2(); + //test2(); + test(); } - public static void test2() throws Exception { + public static void test2(TempleConfig templeConfig) throws Exception { //test(); - TempleConfig templeConfig = new TempleConfig(); Picture picture = new Picture(); - templeConfig.setSensoryNerveNub(4); - templeConfig.setStudyPoint(0.01); - templeConfig.sethTh(0.86); - templeConfig.setRegionNub(200); - templeConfig.setMaxRain(360); - templeConfig.setSoftMax(true); List specificationsList = new ArrayList<>(); Specifications specifications = new Specifications(); specifications.setWidth(400); specifications.setHeight(400); specificationsList.add(specifications); - templeConfig.setClassifier(Classifier.LVQ); - templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3); Operation operation = new Operation(templeConfig); - ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("D:\\cai\\e/e1.jpg"); - operation.colorLook(threeChannelMatrix, specificationsList); + for (int i = 1; i <= 10; i++) { + ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic/a" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic/b" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\pic/c" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\pic/d" + i + ".jpg"); + RegionBody regionBody1 = operation.colorLook(threeChannelMatrix1, specificationsList).get(0); + RegionBody regionBody2 = operation.colorLook(threeChannelMatrix2, specificationsList).get(0); + RegionBody regionBody3 = operation.colorLook(threeChannelMatrix3, specificationsList).get(0); + RegionBody regionBody4 = operation.colorLook(threeChannelMatrix4, specificationsList).get(0); + System.out.println("type1==" + regionBody1.getType()); + System.out.println("type2==" + regionBody2.getType()); + System.out.println("type3==" + regionBody3.getType()); + System.out.println("type4==" + regionBody4.getType()); + System.out.println("==========================================" + i); + } } public static void test() throws Exception { TempleConfig templeConfig = new TempleConfig(); Picture picture = new Picture(); - templeConfig.setSensoryNerveNub(4); templeConfig.setStudyPoint(0.01); + templeConfig.isShowLog(true); + templeConfig.setMaxRain(320); + templeConfig.setSensoryNerveNub(3); templeConfig.setSoftMax(true); - - templeConfig.setClassifier(Classifier.LVQ); - templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3); + templeConfig.setFeatureNub(3); + templeConfig.sethTh(0.88); + templeConfig.setPoolSize(2); + templeConfig.setRzType(RZ.L1); + templeConfig.setlParam(0.015); + templeConfig.setClassifier(Classifier.VAvg); + templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 4); Operation operation = new Operation(templeConfig); List specificationsList = new ArrayList<>(); Specifications specifications = new Specifications(); @@ -63,15 +70,19 @@ public class FoodTest { specificationsList.add(specifications); for (int j = 0; j < 1; j++) { for (int i = 1; i <= 10; i++) { - ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\cai/a/a" + i + ".jpg"); - ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\cai/b/b" + i + ".jpg"); - ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\cai/c/c" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("D:\\pic/a" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("D:\\pic/b" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("D:\\pic/c" + i + ".jpg"); + ThreeChannelMatrix threeChannelMatrix4 = picture.getThreeMatrix("D:\\pic/d" + i + ".jpg"); operation.colorStudy(threeChannelMatrix1, 1, specificationsList); operation.colorStudy(threeChannelMatrix2, 2, specificationsList); operation.colorStudy(threeChannelMatrix3, 3, specificationsList); - System.out.println("======================================="); + operation.colorStudy(threeChannelMatrix4, 4, specificationsList); + System.out.println("=======================================" + i); } } + //templeConfig.finishStudy(); + //test2(templeConfig); } public static void study() throws Exception { @@ -89,8 +100,7 @@ public class FoodTest { Picture picture = new Picture(); Convolution convolution = new Convolution(); ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix(""); - List feature = convolution.getCenterColor(threeChannelMatrix, 2, 4); - + // List feature = convolution.getCenterColor(threeChannelMatrix, 2, 4); } public static void food() throws Exception {