增加KNN的模型注入与获取

This commit is contained in:
thenk008 2020-05-23 13:49:18 +08:00
parent 2dac3dadf9
commit 59500f4388
5 changed files with 48 additions and 1 deletions

View File

@ -188,7 +188,7 @@ public class Operation {//进行计算
for (Map.Entry<Integer, Matrix> entry : matrixK.entrySet()) {
Matrix matrix = entry.getValue();
double dist = MatrixOperation.getEDist(matrix, myVector);
System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString());
//System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString());
if (minDist == 0 || dist < minDist) {
minDist = dist;
id = entry.getKey();

View File

@ -542,6 +542,22 @@ public class TempleConfig {
modelParameter.setMatrixK(map);
}
break;
case Classifier.KNN:
if (knn != null) {
Map<Integer, List<Matrix>> listMap = knn.getFeatureMap();
Map<Integer, List<List<Double>>> knnVector = new HashMap<>();
for (Map.Entry<Integer, List<Matrix>> entry : listMap.entrySet()) {
List<Matrix> list = entry.getValue();
List<List<Double>> listFeature = new ArrayList<>();
for (Matrix matrix : list) {
List<Double> list1 = MatrixOperation.rowVectorToList(matrix);
listFeature.add(list1);
}
knnVector.put(entry.getKey(), listFeature);
}
modelParameter.setKnnVector(knnVector);
}
break;
}
}
@ -651,6 +667,19 @@ public class TempleConfig {
case Classifier.DNN:
nerveManager.insertModelParameter(modelParameter);
break;
case Classifier.KNN:
Map<Integer, List<List<Double>>> knnVector = modelParameter.getKnnVector();
if (knn != null && knnVector != null) {
for (Map.Entry<Integer, List<List<Double>>> entry : knnVector.entrySet()) {
List<List<Double>> featureList = entry.getValue();
int type = entry.getKey();
for (List<Double> list : featureList) {
Matrix matrix = MatrixOperation.listToRowVector(list);
knn.insertMatrix(matrix, type);
}
}
}
break;
}
}
if (modelParameter.getFrame() != null) {

View File

@ -17,6 +17,14 @@ public class Knn {//KNN分类器
this.nub = nub;
}
public Map<Integer, List<Matrix>> getFeatureMap() {
return featureMap;
}
public void removeType(int type) {
featureMap.remove(type);
}
public void insertMatrix(Matrix vector, int tag) throws Exception {
if (vector.isVector() && vector.isRowVector()) {
if (featureMap.size() == 0) {

View File

@ -23,9 +23,18 @@ public class ModelParameter {
private Map<Integer, KBorder> borderMap = new HashMap<>();//边框距离模型
private LvqModel lvqModel;//LVQ模型
private Map<Integer, List<Double>> matrixK = new HashMap<>();//均值特征向量
private Map<Integer, List<List<Double>>> knnVector;//Knn模型
private Frame frame;//先验边框
private double dnnAvg;//
public Map<Integer, List<List<Double>>> getKnnVector() {
return knnVector;
}
public void setKnnVector(Map<Integer, List<List<Double>>> knnVector) {
this.knnVector = knnVector;
}
public double getDnnAvg() {
return dnnAvg;
}

View File

@ -55,6 +55,7 @@ public class FoodTest {
templeConfig.setMaxRain(320);//切割阈值
templeConfig.setFeatureNub(4);
templeConfig.sethTh(0.88);
templeConfig.setKnnNub(7);
templeConfig.setPoolSize(2);
templeConfig.setRegionNub(200);
templeConfig.setClassifier(Classifier.KNN);