使原型聚类特征依范数进行排序

This commit is contained in:
李大鹏 2024-04-11 14:45:16 +08:00
parent 72a6fb0eb7
commit de0e2f9c6b
4 changed files with 47 additions and 14 deletions

View File

@ -1,7 +1,6 @@
package org.wlld.entity; package org.wlld.entity;
import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.Matrix;
import org.wlld.tools.ArithUtil;
import org.wlld.tools.RgbRegression; import org.wlld.tools.RgbRegression;
import java.util.*; import java.util.*;
@ -262,7 +261,7 @@ public class RGBNorm {
double sigma = 0; double sigma = 0;
if (nub > 0) { if (nub > 0) {
for (int i = 0; i < rgb.length; i++) { for (int i = 0; i < rgb.length; i++) {
double rgbc = ArithUtil.div(rgbAll[i], nub); double rgbc = rgbAll[i] / nub;
rgb[i] = rgbc; rgb[i] = rgbc;
sigma = sigma + Math.pow(rgbc, 2); sigma = sigma + Math.pow(rgbc, 2);
} }

View File

@ -21,8 +21,8 @@ public class GMClustering extends MeanClustering {
this.regionSize = regionSize; this.regionSize = regionSize;
} }
public GMClustering(int speciesQuantity) throws Exception { public GMClustering(int speciesQuantity, int maxTimes) throws Exception {
super(speciesQuantity); super(speciesQuantity, maxTimes);
} }
public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度 public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度

View File

@ -4,6 +4,7 @@ package org.wlld.tools;
import org.wlld.entity.RGBNorm; import org.wlld.entity.RGBNorm;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
@ -12,23 +13,39 @@ public class MeanClustering {
protected List<double[]> matrixList = new ArrayList<>();//聚类集合 protected List<double[]> matrixList = new ArrayList<>();//聚类集合
private int length;//向量长度(模型需要返回) private int length;//向量长度(模型需要返回)
protected int speciesQuantity;//种类数量(模型需要返回) protected int speciesQuantity;//种类数量(模型需要返回)
private final int maxTimes;//最大迭代次数
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回) protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
public List<RGBNorm> getMatrices() { public List<RGBNorm> getMatrices() {
return matrices; return matrices;
} }
public MeanClustering(int speciesQuantity) throws Exception { public double[] getResultByNorm() {
this.speciesQuantity = speciesQuantity;//聚类的数量 MeanSort meanSort = new MeanSort();
double[] dm = new double[matrices.size() * length];
matrices.sort(meanSort);
for (int i = 0; i < matrices.size(); i++) {
RGBNorm rgbNorm = matrices.get(i);
double[] rgb = rgbNorm.getRgb();
for (int j = 0; j < rgb.length; j++) {
dm[i * rgb.length + j] = rgb[j];
}
}
return dm;
} }
public void setColor(double[] color) throws Exception { public MeanClustering(int speciesQuantity, int maxTimes) throws Exception {
if (matrixList.size() == 0) { this.speciesQuantity = speciesQuantity;//聚类的数量
matrixList.add(color); this.maxTimes = maxTimes;
length = color.length; }
public void setFeature(double[] feature) throws Exception {
if (matrixList.isEmpty()) {
matrixList.add(feature);
length = feature.length;
} else { } else {
if (length == color.length) { if (length == feature.length) {
matrixList.add(color); matrixList.add(feature);
} else { } else {
throw new Exception("vector length is different"); throw new Exception("vector length is different");
} }
@ -86,11 +103,11 @@ public class MeanClustering {
} }
//进行两者的比较 //进行两者的比较
boolean isNext; boolean isNext;
for (int i = 0; i < 60; i++) { for (int i = 0; i < maxTimes; i++) {
//System.out.println("聚类:" + i); //System.out.println("聚类:" + i);
averageMatrix(); averageMatrix();
isNext = isNext(); isNext = isNext();
if (isNext && i < 59) { if (isNext && i < maxTimes - 1) {
clear(); clear();
} else { } else {
break; break;

View File

@ -0,0 +1,17 @@
package org.wlld.tools;
import org.wlld.entity.RGBNorm;
import java.util.Comparator;
public class MeanSort implements Comparator<RGBNorm> {
@Override
public int compare(RGBNorm o1, RGBNorm o2) {
if (o1.getNorm() > o2.getNorm()) {
return 1;
} else if (o1.getNorm() < o2.getNorm()) {
return -1;
}
return 0;
}
}