mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 11:48:08 +08:00
增加KNN的模型注入与获取
This commit is contained in:
parent
2dac3dadf9
commit
59500f4388
@ -188,7 +188,7 @@ public class Operation {//进行计算
|
||||
for (Map.Entry<Integer, Matrix> entry : matrixK.entrySet()) {
|
||||
Matrix matrix = entry.getValue();
|
||||
double dist = MatrixOperation.getEDist(matrix, myVector);
|
||||
System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString());
|
||||
//System.out.println("距离===" + dist + ",类别==" + entry.getKey()+",核心:"+matrix.getString());
|
||||
if (minDist == 0 || dist < minDist) {
|
||||
minDist = dist;
|
||||
id = entry.getKey();
|
||||
|
@ -542,6 +542,22 @@ public class TempleConfig {
|
||||
modelParameter.setMatrixK(map);
|
||||
}
|
||||
break;
|
||||
case Classifier.KNN:
|
||||
if (knn != null) {
|
||||
Map<Integer, List<Matrix>> listMap = knn.getFeatureMap();
|
||||
Map<Integer, List<List<Double>>> knnVector = new HashMap<>();
|
||||
for (Map.Entry<Integer, List<Matrix>> entry : listMap.entrySet()) {
|
||||
List<Matrix> list = entry.getValue();
|
||||
List<List<Double>> listFeature = new ArrayList<>();
|
||||
for (Matrix matrix : list) {
|
||||
List<Double> list1 = MatrixOperation.rowVectorToList(matrix);
|
||||
listFeature.add(list1);
|
||||
}
|
||||
knnVector.put(entry.getKey(), listFeature);
|
||||
}
|
||||
modelParameter.setKnnVector(knnVector);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
@ -651,6 +667,19 @@ public class TempleConfig {
|
||||
case Classifier.DNN:
|
||||
nerveManager.insertModelParameter(modelParameter);
|
||||
break;
|
||||
case Classifier.KNN:
|
||||
Map<Integer, List<List<Double>>> knnVector = modelParameter.getKnnVector();
|
||||
if (knn != null && knnVector != null) {
|
||||
for (Map.Entry<Integer, List<List<Double>>> entry : knnVector.entrySet()) {
|
||||
List<List<Double>> featureList = entry.getValue();
|
||||
int type = entry.getKey();
|
||||
for (List<Double> list : featureList) {
|
||||
Matrix matrix = MatrixOperation.listToRowVector(list);
|
||||
knn.insertMatrix(matrix, type);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (modelParameter.getFrame() != null) {
|
||||
|
@ -17,6 +17,14 @@ public class Knn {//KNN分类器
|
||||
this.nub = nub;
|
||||
}
|
||||
|
||||
public Map<Integer, List<Matrix>> getFeatureMap() {
|
||||
return featureMap;
|
||||
}
|
||||
|
||||
public void removeType(int type) {
|
||||
featureMap.remove(type);
|
||||
}
|
||||
|
||||
public void insertMatrix(Matrix vector, int tag) throws Exception {
|
||||
if (vector.isVector() && vector.isRowVector()) {
|
||||
if (featureMap.size() == 0) {
|
||||
|
@ -23,9 +23,18 @@ public class ModelParameter {
|
||||
private Map<Integer, KBorder> borderMap = new HashMap<>();//边框距离模型
|
||||
private LvqModel lvqModel;//LVQ模型
|
||||
private Map<Integer, List<Double>> matrixK = new HashMap<>();//均值特征向量
|
||||
private Map<Integer, List<List<Double>>> knnVector;//Knn模型
|
||||
private Frame frame;//先验边框
|
||||
private double dnnAvg;//
|
||||
|
||||
public Map<Integer, List<List<Double>>> getKnnVector() {
|
||||
return knnVector;
|
||||
}
|
||||
|
||||
public void setKnnVector(Map<Integer, List<List<Double>>> knnVector) {
|
||||
this.knnVector = knnVector;
|
||||
}
|
||||
|
||||
public double getDnnAvg() {
|
||||
return dnnAvg;
|
||||
}
|
||||
|
@ -55,6 +55,7 @@ public class FoodTest {
|
||||
templeConfig.setMaxRain(320);//切割阈值
|
||||
templeConfig.setFeatureNub(4);
|
||||
templeConfig.sethTh(0.88);
|
||||
templeConfig.setKnnNub(7);
|
||||
templeConfig.setPoolSize(2);
|
||||
templeConfig.setRegionNub(200);
|
||||
templeConfig.setClassifier(Classifier.KNN);
|
||||
|
Loading…
Reference in New Issue
Block a user