mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 03:38:08 +08:00
使原型聚类特征依范数进行排序
This commit is contained in:
parent
72a6fb0eb7
commit
de0e2f9c6b
@ -1,7 +1,6 @@
|
|||||||
package org.wlld.entity;
|
package org.wlld.entity;
|
||||||
|
|
||||||
import org.wlld.MatrixTools.Matrix;
|
import org.wlld.MatrixTools.Matrix;
|
||||||
import org.wlld.tools.ArithUtil;
|
|
||||||
import org.wlld.tools.RgbRegression;
|
import org.wlld.tools.RgbRegression;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
@ -262,7 +261,7 @@ public class RGBNorm {
|
|||||||
double sigma = 0;
|
double sigma = 0;
|
||||||
if (nub > 0) {
|
if (nub > 0) {
|
||||||
for (int i = 0; i < rgb.length; i++) {
|
for (int i = 0; i < rgb.length; i++) {
|
||||||
double rgbc = ArithUtil.div(rgbAll[i], nub);
|
double rgbc = rgbAll[i] / nub;
|
||||||
rgb[i] = rgbc;
|
rgb[i] = rgbc;
|
||||||
sigma = sigma + Math.pow(rgbc, 2);
|
sigma = sigma + Math.pow(rgbc, 2);
|
||||||
}
|
}
|
||||||
|
@ -21,8 +21,8 @@ public class GMClustering extends MeanClustering {
|
|||||||
this.regionSize = regionSize;
|
this.regionSize = regionSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
public GMClustering(int speciesQuantity) throws Exception {
|
public GMClustering(int speciesQuantity, int maxTimes) throws Exception {
|
||||||
super(speciesQuantity);
|
super(speciesQuantity, maxTimes);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度
|
public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度
|
||||||
|
@ -4,6 +4,7 @@ package org.wlld.tools;
|
|||||||
import org.wlld.entity.RGBNorm;
|
import org.wlld.entity.RGBNorm;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
@ -12,23 +13,39 @@ public class MeanClustering {
|
|||||||
protected List<double[]> matrixList = new ArrayList<>();//聚类集合
|
protected List<double[]> matrixList = new ArrayList<>();//聚类集合
|
||||||
private int length;//向量长度(模型需要返回)
|
private int length;//向量长度(模型需要返回)
|
||||||
protected int speciesQuantity;//种类数量(模型需要返回)
|
protected int speciesQuantity;//种类数量(模型需要返回)
|
||||||
|
private final int maxTimes;//最大迭代次数
|
||||||
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
|
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
|
||||||
|
|
||||||
public List<RGBNorm> getMatrices() {
|
public List<RGBNorm> getMatrices() {
|
||||||
return matrices;
|
return matrices;
|
||||||
}
|
}
|
||||||
|
|
||||||
public MeanClustering(int speciesQuantity) throws Exception {
|
public double[] getResultByNorm() {
|
||||||
this.speciesQuantity = speciesQuantity;//聚类的数量
|
MeanSort meanSort = new MeanSort();
|
||||||
|
double[] dm = new double[matrices.size() * length];
|
||||||
|
matrices.sort(meanSort);
|
||||||
|
for (int i = 0; i < matrices.size(); i++) {
|
||||||
|
RGBNorm rgbNorm = matrices.get(i);
|
||||||
|
double[] rgb = rgbNorm.getRgb();
|
||||||
|
for (int j = 0; j < rgb.length; j++) {
|
||||||
|
dm[i * rgb.length + j] = rgb[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dm;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setColor(double[] color) throws Exception {
|
public MeanClustering(int speciesQuantity, int maxTimes) throws Exception {
|
||||||
if (matrixList.size() == 0) {
|
this.speciesQuantity = speciesQuantity;//聚类的数量
|
||||||
matrixList.add(color);
|
this.maxTimes = maxTimes;
|
||||||
length = color.length;
|
}
|
||||||
|
|
||||||
|
public void setFeature(double[] feature) throws Exception {
|
||||||
|
if (matrixList.isEmpty()) {
|
||||||
|
matrixList.add(feature);
|
||||||
|
length = feature.length;
|
||||||
} else {
|
} else {
|
||||||
if (length == color.length) {
|
if (length == feature.length) {
|
||||||
matrixList.add(color);
|
matrixList.add(feature);
|
||||||
} else {
|
} else {
|
||||||
throw new Exception("vector length is different");
|
throw new Exception("vector length is different");
|
||||||
}
|
}
|
||||||
@ -86,11 +103,11 @@ public class MeanClustering {
|
|||||||
}
|
}
|
||||||
//进行两者的比较
|
//进行两者的比较
|
||||||
boolean isNext;
|
boolean isNext;
|
||||||
for (int i = 0; i < 60; i++) {
|
for (int i = 0; i < maxTimes; i++) {
|
||||||
//System.out.println("聚类:" + i);
|
//System.out.println("聚类:" + i);
|
||||||
averageMatrix();
|
averageMatrix();
|
||||||
isNext = isNext();
|
isNext = isNext();
|
||||||
if (isNext && i < 59) {
|
if (isNext && i < maxTimes - 1) {
|
||||||
clear();
|
clear();
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
17
src/main/java/org/wlld/tools/MeanSort.java
Normal file
17
src/main/java/org/wlld/tools/MeanSort.java
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package org.wlld.tools;
|
||||||
|
|
||||||
|
import org.wlld.entity.RGBNorm;
|
||||||
|
|
||||||
|
import java.util.Comparator;
|
||||||
|
|
||||||
|
public class MeanSort implements Comparator<RGBNorm> {
|
||||||
|
@Override
|
||||||
|
public int compare(RGBNorm o1, RGBNorm o2) {
|
||||||
|
if (o1.getNorm() > o2.getNorm()) {
|
||||||
|
return 1;
|
||||||
|
} else if (o1.getNorm() < o2.getNorm()) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user