修改二叉树回归误差比较

This commit is contained in:
lidapeng 2020-09-23 16:30:54 +08:00
parent 58db5df18c
commit 382c3ca598
5 changed files with 139 additions and 22 deletions

View File

@ -59,6 +59,20 @@ public class Forest extends Frequency {
this.resultVariance = resultVariance;
}
public double getMappingFeature(Matrix feature) throws Exception {//获取映射后的特征
double nub;
if (feature.isRowVector()) {
if (isOldG) {//使用原有基
nub = feature.getNumber(0, oldGId);
} else {
nub = transG(pc1, feature, gNorm);
}
} else {
throw new Exception("feature is not a rowVector");
}
return nub;
}
private double[] findG() throws Exception {//寻找新的切入维度
// 先尝试从原有维度切入
int xSize = conditionMatrix.getX();

View File

@ -52,22 +52,15 @@ public class RegressionForest extends Frequency {
}
}
public double getDist(double[] feature, double result) {//获取特征误差结果
Forest forestFinish;
if (result <= min) {//直接找下边界区域
forestFinish = getLimitRegion(forest, false);
} else if (result >= max) {//直接找到上边界区域
forestFinish = getLimitRegion(forest, true);
} else {
forestFinish = getRegion(forest, result);
}
public double getDist(Matrix featureMatrix, double result) throws Exception {//获取特征误差结果
Forest forestFinish = getRegion(forest, featureMatrix);
//计算误差
double[] w = forestFinish.getW();
double sigma = 0;
for (int i = 0; i < w.length; i++) {
double nub;
if (i < w.length - 1) {
nub = w[i] * feature[i];
nub = w[i] * featureMatrix.getNumber(0, i);
} else {
nub = w[i];
}
@ -76,8 +69,9 @@ public class RegressionForest extends Frequency {
return Math.abs(result - sigma);
}
private Forest getRegion(Forest forest, double result) {
private Forest getRegion(Forest forest, Matrix matrix) throws Exception {
double median = forest.getMedian();
double result = forest.getMappingFeature(matrix);
if (result > median && forest.getForestRight() != null) {//向右走
forest = forest.getForestRight();
} else if (result <= median && forest.getForestLeft() != null) {//向左走
@ -85,7 +79,7 @@ public class RegressionForest extends Frequency {
} else {
return forest;
}
return getRegion(forest, result);
return getRegion(forest, matrix);
}
private Forest getLimitRegion(Forest forest, boolean isMax) {
@ -202,7 +196,7 @@ public class RegressionForest extends Frequency {
}
}
private void pruning() throws Exception {//剪枝
private void pruning() {//剪枝
//先获取当前最大id
int max = forestMap.lastKey();
int layersNub = (int) (Math.log(max) / Math.log(2));//当前的层数

View File

@ -0,0 +1,35 @@
package org.wlld.regressionForest;
import org.wlld.MatrixTools.Matrix;
import org.wlld.imageRecognition.ThreeChannelMatrix;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description rgb 替换为权重
*/
public class RgbFilter {
public void filter(ThreeChannelMatrix threeChannelMatrix, RegressionForest regressionForest) throws Exception {
Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB();
int x = matrixR.getX();
int y = matrixR.getY();
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
double[] feature = new double[]{matrixR.getNumber(i, j) / 255, matrixG.getNumber(i, j) / 255};
double result = matrixB.getNumber(i, j);
regressionForest.insertFeature(feature, result);
}
}
regressionForest.startStudy();
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
double[] feature = new double[]{matrixR.getNumber(i, j) / 255, matrixG.getNumber(i, j) / 255};
double result = matrixB.getNumber(i, j);
}
}
}
}

View File

@ -1,18 +1,14 @@
package coverTest;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.ModelData;
import org.wlld.config.Classifier;
import org.wlld.config.RZ;
import org.wlld.config.StudyPattern;
import org.wlld.function.Sigmod;
import org.wlld.imageRecognition.*;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.tools.ArithUtil;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

View File

@ -1,6 +1,18 @@
package coverTest;
import org.wlld.randomForest.Tree;
import org.wlld.MatrixTools.Matrix;
import org.wlld.config.Classifier;
import org.wlld.config.StudyPattern;
import org.wlld.imageRecognition.Convolution;
import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.imageRecognition.ThreeChannelMatrix;
import org.wlld.imageRecognition.segmentation.RegionBody;
import org.wlld.imageRecognition.segmentation.Specifications;
import org.wlld.imageRecognition.segmentation.Watershed;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.param.Cutting;
import org.wlld.param.Food;
import org.wlld.regressionForest.RegressionForest;
import java.util.*;
@ -12,6 +24,8 @@ import java.util.*;
* @Description
*/
public class ForestTest {
private static Convolution convolution = new Convolution();
public static void main(String[] args) throws Exception {
test();
//int a = (int) (Math.log(4) / Math.log(2));//id22是第几层
@ -20,6 +34,66 @@ public class ForestTest {
}
public static void testPic() throws Exception {
Picture picture = new Picture();
TempleConfig templeConfig = getTemple(null);
List<Specifications> specificationsList = new ArrayList<>();
Specifications specifications = new Specifications();
specifications.setMinWidth(100);
specifications.setMinHeight(100);
specifications.setMaxWidth(1000);
specifications.setMaxHeight(1000);
specificationsList.add(specifications);
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg");
Watershed watershed = new Watershed(threeChannelMatrix, specificationsList, templeConfig);
List<RegionBody> regionBodies = watershed.rainfall();
RegionBody regionBody = regionBodies.get(0);
int minX = regionBody.getMinX();
int minY = regionBody.getMinY();
int maxX = regionBody.getMaxX();
int maxY = regionBody.getMaxY();
int xSize = maxX - minX;
int ySize = maxY - minY;
ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize);
}
public static void picDo(ThreeChannelMatrix threeChannelMatrix) {//进行lbp特征提取
}
public static TempleConfig getTemple(ModelParameter modelParameter) throws Exception {
TempleConfig templeConfig = new TempleConfig();
//templeConfig.isShowLog(true);//是否打印日志
Cutting cutting = templeConfig.getCutting();
Food food = templeConfig.getFood();
//
cutting.setMaxRain(360);//切割阈值
cutting.setTh(0.6);
cutting.setRegionNub(200);
cutting.setMaxIou(0.5);
//knn参数
templeConfig.setKnnNub(1);
//池化比例
templeConfig.setPoolSize(2);//缩小比例
//聚类
templeConfig.setFeatureNub(3);//聚类特征数量
//菜品识别实体类
food.setShrink(10);//缩紧像素
food.setTimes(2);//聚类数据增强
food.setRowMark(0.1);//0.12
food.setColumnMark(0.1);//0.25
food.setRegressionNub(20000);
food.setTrayTh(0.08);
templeConfig.setClassifier(Classifier.KNN);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
if (modelParameter != null) {
templeConfig.insertModel(modelParameter);
}
return templeConfig;
}
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
@ -41,18 +115,22 @@ public class ForestTest {
double sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = a1.get(i);
double[] test = new double[]{feature[0], feature[1]};
Matrix test = new Matrix(1, 3);
test.setNub(0, 0, feature[0]);
test.setNub(0, 1, feature[1]);
test.setNub(0, 2, feature[2]);
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}
double avs = sigma / size;
System.out.println("a误差" + avs);
// a误差0.0017585065712555645
// b误差0.00761733737464547
sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = b1.get(i);
double[] test = new double[]{feature[0], feature[1]};
Matrix test = new Matrix(1, 3);
test.setNub(0, 0, feature[0]);
test.setNub(0, 1, feature[1]);
test.setNub(0, 2, feature[2]);
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}