修改决策树精度损失问题

This commit is contained in:
李大鹏 2024-03-03 20:42:07 +08:00
parent 5aebb15cd9
commit feaa086a21

View File

@ -1,6 +1,5 @@
package org.wlld.randomForest;
import org.wlld.tools.ArithUtil;
import java.lang.reflect.Method;
import java.util.*;
@ -50,7 +49,7 @@ public class Tree {//决策树
}
private double log2(double p) {
return ArithUtil.div(Math.log(p), Math.log(2));
return Math.log(p) / Math.log(2);
}
private double getEnt(List<Integer> list) {
@ -67,14 +66,14 @@ public class Tree {//决策树
double ent = 0;
//求信息熵
for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
double g = ArithUtil.div(entry1.getValue(), list.size());//每个类别的概率
ent = ArithUtil.add(ent, ArithUtil.mul(g, log2(g)));
double g = (double) entry1.getValue() / (double) list.size();//每个类别的概率
ent = ent + g * log2(g);
}
return -ent;
}
private double getGain(double ent, double dNub, double gain) {
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
return gain + ent * dNub;
}
private List<Node> createNode(Node node) {
@ -122,19 +121,19 @@ public class Tree {//决策树
sonNode.typeId = entry.getKey();//该属性值
int myNub = list.size();//该属性值下数据的数量
double ent = getEnt(list);//该属性值 的信息熵
double dNub = ArithUtil.div(myNub, fatherNub);//该属性值在 父级样本中出现的概率
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
double dNub = (double) myNub / (double) fatherNub;//该属性值在 父级样本中出现的概率
IV = dNub * log2(dNub) + IV;
gain = getGain(ent, dNub, gain);
}
Gain gain1 = new Gain();
gainMap.put(name, gain1);
gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益
gain1.gain = fatherEnt - gain;//信息增益
if (IV != 0) {
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
gain1.gainRatio = gain1.gain / -IV;//增益率
} else {
gain1.gainRatio = 1000000;
}
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
sigmaG = gain1.gain + sigmaG;
i++;
}
double avgGain = sigmaG / i;
@ -143,9 +142,9 @@ public class Tree {//决策树
System.out.println("平均信息增益==============================" + avgGain);
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
Gain gain = entry.getValue();
System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
+ "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
if (gainMap.size() == 1 || (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
if (gainMap.size() == 1 || ((gain.gain >= avgGain || Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
gainRatio = gain.gainRatio;
key = entry.getKey();
}
@ -222,7 +221,7 @@ public class Tree {//决策树
private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
//System.out.println("惩罚");
double trust = treeWithTrust.getTrust();//获取当前信任值
trust = ArithUtil.mul(trust, trustPunishment);
trust = trust * trustPunishment;
treeWithTrust.setTrust(trust);
}
@ -312,7 +311,7 @@ public class Tree {//决策树
int fatherType = getType(father.fatherList);
int nub = getRightPoint(father.fatherList, fatherType);
//父级该样本正确率
double rightFather = ArithUtil.div(nub, father.fatherList.size());
double rightFather = (double) nub / (double) father.fatherList.size();
int rightNub = 0;
int rightAllNub = 0;
for (int i = 0; i < sonNodes.size(); i++) {
@ -322,7 +321,7 @@ public class Tree {//决策树
rightNub = rightNub + right;
rightAllNub = rightAllNub + list.size();
}
double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率
double rightPoint = (double) rightNub / (double) rightAllNub;//子节点正确率
if (rightPoint <= rightFather) {
isRemove = true;
}