mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 03:38:08 +08:00
修改二叉树回归误差比较
This commit is contained in:
parent
58db5df18c
commit
382c3ca598
@ -59,6 +59,20 @@ public class Forest extends Frequency {
|
||||
this.resultVariance = resultVariance;
|
||||
}
|
||||
|
||||
public double getMappingFeature(Matrix feature) throws Exception {//获取映射后的特征
|
||||
double nub;
|
||||
if (feature.isRowVector()) {
|
||||
if (isOldG) {//使用原有基
|
||||
nub = feature.getNumber(0, oldGId);
|
||||
} else {
|
||||
nub = transG(pc1, feature, gNorm);
|
||||
}
|
||||
} else {
|
||||
throw new Exception("feature is not a rowVector");
|
||||
}
|
||||
return nub;
|
||||
}
|
||||
|
||||
private double[] findG() throws Exception {//寻找新的切入维度
|
||||
// 先尝试从原有维度切入
|
||||
int xSize = conditionMatrix.getX();
|
||||
|
@ -52,22 +52,15 @@ public class RegressionForest extends Frequency {
|
||||
}
|
||||
}
|
||||
|
||||
public double getDist(double[] feature, double result) {//获取特征误差结果
|
||||
Forest forestFinish;
|
||||
if (result <= min) {//直接找下边界区域
|
||||
forestFinish = getLimitRegion(forest, false);
|
||||
} else if (result >= max) {//直接找到上边界区域
|
||||
forestFinish = getLimitRegion(forest, true);
|
||||
} else {
|
||||
forestFinish = getRegion(forest, result);
|
||||
}
|
||||
public double getDist(Matrix featureMatrix, double result) throws Exception {//获取特征误差结果
|
||||
Forest forestFinish = getRegion(forest, featureMatrix);
|
||||
//计算误差
|
||||
double[] w = forestFinish.getW();
|
||||
double sigma = 0;
|
||||
for (int i = 0; i < w.length; i++) {
|
||||
double nub;
|
||||
if (i < w.length - 1) {
|
||||
nub = w[i] * feature[i];
|
||||
nub = w[i] * featureMatrix.getNumber(0, i);
|
||||
} else {
|
||||
nub = w[i];
|
||||
}
|
||||
@ -76,8 +69,9 @@ public class RegressionForest extends Frequency {
|
||||
return Math.abs(result - sigma);
|
||||
}
|
||||
|
||||
private Forest getRegion(Forest forest, double result) {
|
||||
private Forest getRegion(Forest forest, Matrix matrix) throws Exception {
|
||||
double median = forest.getMedian();
|
||||
double result = forest.getMappingFeature(matrix);
|
||||
if (result > median && forest.getForestRight() != null) {//向右走
|
||||
forest = forest.getForestRight();
|
||||
} else if (result <= median && forest.getForestLeft() != null) {//向左走
|
||||
@ -85,7 +79,7 @@ public class RegressionForest extends Frequency {
|
||||
} else {
|
||||
return forest;
|
||||
}
|
||||
return getRegion(forest, result);
|
||||
return getRegion(forest, matrix);
|
||||
}
|
||||
|
||||
private Forest getLimitRegion(Forest forest, boolean isMax) {
|
||||
@ -202,7 +196,7 @@ public class RegressionForest extends Frequency {
|
||||
}
|
||||
}
|
||||
|
||||
private void pruning() throws Exception {//剪枝
|
||||
private void pruning() {//剪枝
|
||||
//先获取当前最大id
|
||||
int max = forestMap.lastKey();
|
||||
int layersNub = (int) (Math.log(max) / Math.log(2));//当前的层数
|
||||
|
35
src/main/java/org/wlld/regressionForest/RgbFilter.java
Normal file
35
src/main/java/org/wlld/regressionForest/RgbFilter.java
Normal file
@ -0,0 +1,35 @@
|
||||
package org.wlld.regressionForest;
|
||||
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.imageRecognition.ThreeChannelMatrix;
|
||||
|
||||
/**
|
||||
* @param
|
||||
* @DATA
|
||||
* @Author LiDaPeng
|
||||
* @Description rgb 替换为权重
|
||||
*/
|
||||
public class RgbFilter {
|
||||
|
||||
public void filter(ThreeChannelMatrix threeChannelMatrix, RegressionForest regressionForest) throws Exception {
|
||||
Matrix matrixR = threeChannelMatrix.getMatrixR();
|
||||
Matrix matrixG = threeChannelMatrix.getMatrixG();
|
||||
Matrix matrixB = threeChannelMatrix.getMatrixB();
|
||||
int x = matrixR.getX();
|
||||
int y = matrixR.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
double[] feature = new double[]{matrixR.getNumber(i, j) / 255, matrixG.getNumber(i, j) / 255};
|
||||
double result = matrixB.getNumber(i, j);
|
||||
regressionForest.insertFeature(feature, result);
|
||||
}
|
||||
}
|
||||
regressionForest.startStudy();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
double[] feature = new double[]{matrixR.getNumber(i, j) / 255, matrixG.getNumber(i, j) / 255};
|
||||
double result = matrixB.getNumber(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,18 +1,14 @@
|
||||
package coverTest;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.ModelData;
|
||||
import org.wlld.config.Classifier;
|
||||
import org.wlld.config.RZ;
|
||||
import org.wlld.config.StudyPattern;
|
||||
import org.wlld.function.Sigmod;
|
||||
import org.wlld.imageRecognition.*;
|
||||
import org.wlld.nerveEntity.ModelParameter;
|
||||
import org.wlld.tools.ArithUtil;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -1,6 +1,18 @@
|
||||
package coverTest;
|
||||
|
||||
import org.wlld.randomForest.Tree;
|
||||
import org.wlld.MatrixTools.Matrix;
|
||||
import org.wlld.config.Classifier;
|
||||
import org.wlld.config.StudyPattern;
|
||||
import org.wlld.imageRecognition.Convolution;
|
||||
import org.wlld.imageRecognition.Picture;
|
||||
import org.wlld.imageRecognition.TempleConfig;
|
||||
import org.wlld.imageRecognition.ThreeChannelMatrix;
|
||||
import org.wlld.imageRecognition.segmentation.RegionBody;
|
||||
import org.wlld.imageRecognition.segmentation.Specifications;
|
||||
import org.wlld.imageRecognition.segmentation.Watershed;
|
||||
import org.wlld.nerveEntity.ModelParameter;
|
||||
import org.wlld.param.Cutting;
|
||||
import org.wlld.param.Food;
|
||||
import org.wlld.regressionForest.RegressionForest;
|
||||
|
||||
import java.util.*;
|
||||
@ -12,6 +24,8 @@ import java.util.*;
|
||||
* @Description
|
||||
*/
|
||||
public class ForestTest {
|
||||
private static Convolution convolution = new Convolution();
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
test();
|
||||
//int a = (int) (Math.log(4) / Math.log(2));//id22是第几层
|
||||
@ -20,6 +34,66 @@ public class ForestTest {
|
||||
|
||||
}
|
||||
|
||||
public static void testPic() throws Exception {
|
||||
Picture picture = new Picture();
|
||||
TempleConfig templeConfig = getTemple(null);
|
||||
List<Specifications> specificationsList = new ArrayList<>();
|
||||
Specifications specifications = new Specifications();
|
||||
specifications.setMinWidth(100);
|
||||
specifications.setMinHeight(100);
|
||||
specifications.setMaxWidth(1000);
|
||||
specifications.setMaxHeight(1000);
|
||||
specificationsList.add(specifications);
|
||||
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg");
|
||||
Watershed watershed = new Watershed(threeChannelMatrix, specificationsList, templeConfig);
|
||||
List<RegionBody> regionBodies = watershed.rainfall();
|
||||
RegionBody regionBody = regionBodies.get(0);
|
||||
int minX = regionBody.getMinX();
|
||||
int minY = regionBody.getMinY();
|
||||
int maxX = regionBody.getMaxX();
|
||||
int maxY = regionBody.getMaxY();
|
||||
int xSize = maxX - minX;
|
||||
int ySize = maxY - minY;
|
||||
ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize);
|
||||
|
||||
}
|
||||
|
||||
public static void picDo(ThreeChannelMatrix threeChannelMatrix) {//进行lbp特征提取
|
||||
|
||||
|
||||
}
|
||||
|
||||
public static TempleConfig getTemple(ModelParameter modelParameter) throws Exception {
|
||||
TempleConfig templeConfig = new TempleConfig();
|
||||
//templeConfig.isShowLog(true);//是否打印日志
|
||||
Cutting cutting = templeConfig.getCutting();
|
||||
Food food = templeConfig.getFood();
|
||||
//
|
||||
cutting.setMaxRain(360);//切割阈值
|
||||
cutting.setTh(0.6);
|
||||
cutting.setRegionNub(200);
|
||||
cutting.setMaxIou(0.5);
|
||||
//knn参数
|
||||
templeConfig.setKnnNub(1);
|
||||
//池化比例
|
||||
templeConfig.setPoolSize(2);//缩小比例
|
||||
//聚类
|
||||
templeConfig.setFeatureNub(3);//聚类特征数量
|
||||
//菜品识别实体类
|
||||
food.setShrink(10);//缩紧像素
|
||||
food.setTimes(2);//聚类数据增强
|
||||
food.setRowMark(0.1);//0.12
|
||||
food.setColumnMark(0.1);//0.25
|
||||
food.setRegressionNub(20000);
|
||||
food.setTrayTh(0.08);
|
||||
templeConfig.setClassifier(Classifier.KNN);
|
||||
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
|
||||
if (modelParameter != null) {
|
||||
templeConfig.insertModel(modelParameter);
|
||||
}
|
||||
return templeConfig;
|
||||
}
|
||||
|
||||
public static void test() throws Exception {//对分段回归进行测试
|
||||
int size = 2000;
|
||||
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
|
||||
@ -41,18 +115,22 @@ public class ForestTest {
|
||||
double sigma = 0;
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
double[] feature = a1.get(i);
|
||||
double[] test = new double[]{feature[0], feature[1]};
|
||||
Matrix test = new Matrix(1, 3);
|
||||
test.setNub(0, 0, feature[0]);
|
||||
test.setNub(0, 1, feature[1]);
|
||||
test.setNub(0, 2, feature[2]);
|
||||
double dist = regressionForest.getDist(test, feature[2]);
|
||||
sigma = sigma + dist;
|
||||
}
|
||||
double avs = sigma / size;
|
||||
System.out.println("a误差:" + avs);
|
||||
// a误差:0.0017585065712555645
|
||||
// b误差:0.00761733737464547
|
||||
sigma = 0;
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
double[] feature = b1.get(i);
|
||||
double[] test = new double[]{feature[0], feature[1]};
|
||||
Matrix test = new Matrix(1, 3);
|
||||
test.setNub(0, 0, feature[0]);
|
||||
test.setNub(0, 1, feature[1]);
|
||||
test.setNub(0, 2, feature[2]);
|
||||
double dist = regressionForest.getDist(test, feature[2]);
|
||||
sigma = sigma + dist;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user