From de0e2f9c6b7d6004f72458654d12ea0eff1977ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=A4=A7=E9=B9=8F?= <794757862@qq.com> Date: Thu, 11 Apr 2024 14:45:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E5=8E=9F=E5=9E=8B=E8=81=9A=E7=B1=BB?= =?UTF-8?q?=E7=89=B9=E5=BE=81=E4=BE=9D=E8=8C=83=E6=95=B0=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E6=8E=92=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/entity/RGBNorm.java | 3 +- .../java/org/wlld/tools/GMClustering.java | 4 +- .../java/org/wlld/tools/MeanClustering.java | 37 ++++++++++++++----- src/main/java/org/wlld/tools/MeanSort.java | 17 +++++++++ 4 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 src/main/java/org/wlld/tools/MeanSort.java diff --git a/src/main/java/org/wlld/entity/RGBNorm.java b/src/main/java/org/wlld/entity/RGBNorm.java index 246c8dc..00b4639 100644 --- a/src/main/java/org/wlld/entity/RGBNorm.java +++ b/src/main/java/org/wlld/entity/RGBNorm.java @@ -1,7 +1,6 @@ package org.wlld.entity; import org.wlld.MatrixTools.Matrix; -import org.wlld.tools.ArithUtil; import org.wlld.tools.RgbRegression; import java.util.*; @@ -262,7 +261,7 @@ public class RGBNorm { double sigma = 0; if (nub > 0) { for (int i = 0; i < rgb.length; i++) { - double rgbc = ArithUtil.div(rgbAll[i], nub); + double rgbc = rgbAll[i] / nub; rgb[i] = rgbc; sigma = sigma + Math.pow(rgbc, 2); } diff --git a/src/main/java/org/wlld/tools/GMClustering.java b/src/main/java/org/wlld/tools/GMClustering.java index 34c7e20..4a14e67 100644 --- a/src/main/java/org/wlld/tools/GMClustering.java +++ b/src/main/java/org/wlld/tools/GMClustering.java @@ -21,8 +21,8 @@ public class GMClustering extends MeanClustering { this.regionSize = regionSize; } - public GMClustering(int speciesQuantity) throws Exception { - super(speciesQuantity); + public GMClustering(int speciesQuantity, int maxTimes) throws Exception { + super(speciesQuantity, maxTimes); } public double getProbabilityDensity(double[] feature) throws Exception {//获取总概率密度 diff --git a/src/main/java/org/wlld/tools/MeanClustering.java b/src/main/java/org/wlld/tools/MeanClustering.java index 74705f1..2c80c63 100644 --- a/src/main/java/org/wlld/tools/MeanClustering.java +++ b/src/main/java/org/wlld/tools/MeanClustering.java @@ -4,6 +4,7 @@ package org.wlld.tools; import org.wlld.entity.RGBNorm; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Random; @@ -12,23 +13,39 @@ public class MeanClustering { protected List matrixList = new ArrayList<>();//聚类集合 private int length;//向量长度(模型需要返回) protected int speciesQuantity;//种类数量(模型需要返回) + private final int maxTimes;//最大迭代次数 protected List matrices = new ArrayList<>();//均值K模型(模型需要返回) public List getMatrices() { return matrices; } - public MeanClustering(int speciesQuantity) throws Exception { - this.speciesQuantity = speciesQuantity;//聚类的数量 + public double[] getResultByNorm() { + MeanSort meanSort = new MeanSort(); + double[] dm = new double[matrices.size() * length]; + matrices.sort(meanSort); + for (int i = 0; i < matrices.size(); i++) { + RGBNorm rgbNorm = matrices.get(i); + double[] rgb = rgbNorm.getRgb(); + for (int j = 0; j < rgb.length; j++) { + dm[i * rgb.length + j] = rgb[j]; + } + } + return dm; } - public void setColor(double[] color) throws Exception { - if (matrixList.size() == 0) { - matrixList.add(color); - length = color.length; + public MeanClustering(int speciesQuantity, int maxTimes) throws Exception { + this.speciesQuantity = speciesQuantity;//聚类的数量 + this.maxTimes = maxTimes; + } + + public void setFeature(double[] feature) throws Exception { + if (matrixList.isEmpty()) { + matrixList.add(feature); + length = feature.length; } else { - if (length == color.length) { - matrixList.add(color); + if (length == feature.length) { + matrixList.add(feature); } else { throw new Exception("vector length is different"); } @@ -86,11 +103,11 @@ public class MeanClustering { } //进行两者的比较 boolean isNext; - for (int i = 0; i < 60; i++) { + for (int i = 0; i < maxTimes; i++) { //System.out.println("聚类:" + i); averageMatrix(); isNext = isNext(); - if (isNext && i < 59) { + if (isNext && i < maxTimes - 1) { clear(); } else { break; diff --git a/src/main/java/org/wlld/tools/MeanSort.java b/src/main/java/org/wlld/tools/MeanSort.java new file mode 100644 index 0000000..cff9ecf --- /dev/null +++ b/src/main/java/org/wlld/tools/MeanSort.java @@ -0,0 +1,17 @@ +package org.wlld.tools; + +import org.wlld.entity.RGBNorm; + +import java.util.Comparator; + +public class MeanSort implements Comparator { + @Override + public int compare(RGBNorm o1, RGBNorm o2) { + if (o1.getNorm() > o2.getNorm()) { + return 1; + } else if (o1.getNorm() < o2.getNorm()) { + return -1; + } + return 0; + } +}