mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 03:38:08 +08:00
修改决策树精度损失问题
This commit is contained in:
parent
5aebb15cd9
commit
feaa086a21
@ -1,6 +1,5 @@
|
||||
package org.wlld.randomForest;
|
||||
|
||||
import org.wlld.tools.ArithUtil;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
@ -50,7 +49,7 @@ public class Tree {//决策树
|
||||
}
|
||||
|
||||
/**
 * Returns the base-2 logarithm of {@code p}, i.e. ln(p) / ln(2).
 *
 * NOTE(review): this span is taken from a rendered commit diff, so two return
 * statements appear back to back. The first (via ArithUtil.div) is presumably
 * the removed pre-commit line and the second (plain double division) the added
 * post-commit line - confirm against the repository; only one can remain in
 * compilable code.
 *
 * @param p value whose logarithm is taken
 * @return logarithm of {@code p} to base 2
 */
private double log2(double p) {
|
||||
return ArithUtil.div(Math.log(p), Math.log(2));
|
||||
return Math.log(p) / Math.log(2);
|
||||
}
|
||||
|
||||
private double getEnt(List<Integer> list) {
|
||||
@ -67,14 +66,14 @@ public class Tree {//决策树
|
||||
double ent = 0;
|
||||
//求信息熵
|
||||
for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
|
||||
double g = ArithUtil.div(entry1.getValue(), list.size());//每个类别的概率
|
||||
ent = ArithUtil.add(ent, ArithUtil.mul(g, log2(g)));
|
||||
double g = (double) entry1.getValue() / (double) list.size();//每个类别的概率
|
||||
ent = ent + g * log2(g);
|
||||
}
|
||||
return -ent;
|
||||
}
|
||||
|
||||
/**
 * Adds one weighted term to a running sum: {@code gain + ent * dNub}.
 *
 * NOTE(review): this span is taken from a rendered commit diff, so two return
 * statements appear back to back. The ArithUtil-based line is presumably the
 * removed pre-commit version and the plain double arithmetic the added
 * post-commit version - confirm against the repository.
 *
 * @param ent  value to be weighted (an entropy, judging by the name - verify against caller)
 * @param dNub weight applied to {@code ent}
 * @param gain running sum accumulated so far
 * @return {@code gain} increased by {@code ent * dNub}
 */
private double getGain(double ent, double dNub, double gain) {
|
||||
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
|
||||
return gain + ent * dNub;
|
||||
}
|
||||
|
||||
private List<Node> createNode(Node node) {
|
||||
@ -122,19 +121,19 @@ public class Tree {//决策树
|
||||
sonNode.typeId = entry.getKey();//该属性值
|
||||
int myNub = list.size();//该属性值下数据的数量
|
||||
double ent = getEnt(list);//该属性值 的信息熵
|
||||
double dNub = ArithUtil.div(myNub, fatherNub);//该属性值在 父级样本中出现的概率
|
||||
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
|
||||
double dNub = (double) myNub / (double) fatherNub;//该属性值在 父级样本中出现的概率
|
||||
IV = dNub * log2(dNub) + IV;
|
||||
gain = getGain(ent, dNub, gain);
|
||||
}
|
||||
Gain gain1 = new Gain();
|
||||
gainMap.put(name, gain1);
|
||||
gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益
|
||||
gain1.gain = fatherEnt - gain;//信息增益
|
||||
if (IV != 0) {
|
||||
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
|
||||
gain1.gainRatio = gain1.gain / -IV;//增益率
|
||||
} else {
|
||||
gain1.gainRatio = 1000000;
|
||||
}
|
||||
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
|
||||
sigmaG = gain1.gain + sigmaG;
|
||||
i++;
|
||||
}
|
||||
double avgGain = sigmaG / i;
|
||||
@ -143,9 +142,9 @@ public class Tree {//决策树
|
||||
System.out.println("平均信息增益==============================" + avgGain);
|
||||
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
|
||||
Gain gain = entry.getValue();
|
||||
System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
|
||||
+ "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
|
||||
if (gainMap.size() == 1 || (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
|
||||
// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
|
||||
// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
|
||||
if (gainMap.size() == 1 || ((gain.gain >= avgGain || Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
|
||||
gainRatio = gain.gainRatio;
|
||||
key = entry.getKey();
|
||||
}
|
||||
@ -222,7 +221,7 @@ public class Tree {//决策树
|
||||
/**
 * Trust penalty: reads the trust value from {@code treeWithTrust}, multiplies
 * it by {@code trustPunishment} and writes the result back with setTrust.
 *
 * NOTE(review): this span is taken from a rendered commit diff, so the
 * multiplication appears twice (once via ArithUtil.mul, once as plain double
 * arithmetic). These are presumably the removed and added variants of a single
 * line; if both were kept, the factor would be applied twice - confirm against
 * the repository. trustPunishment is declared outside this view.
 *
 * @param treeWithTrust holder whose trust value is reduced
 */
private void punishment(TreeWithTrust treeWithTrust) {// trust penalty
|
||||
//System.out.println("惩罚");
|
||||
double trust = treeWithTrust.getTrust();// read the current trust value
|
||||
trust = ArithUtil.mul(trust, trustPunishment);
|
||||
trust = trust * trustPunishment;
|
||||
treeWithTrust.setTrust(trust);
|
||||
}
|
||||
|
||||
@ -312,7 +311,7 @@ public class Tree {//决策树
|
||||
int fatherType = getType(father.fatherList);
|
||||
int nub = getRightPoint(father.fatherList, fatherType);
|
||||
//父级该样本正确率
|
||||
double rightFather = ArithUtil.div(nub, father.fatherList.size());
|
||||
double rightFather = (double) nub / (double) father.fatherList.size();
|
||||
int rightNub = 0;
|
||||
int rightAllNub = 0;
|
||||
for (int i = 0; i < sonNodes.size(); i++) {
|
||||
@ -322,7 +321,7 @@ public class Tree {//决策树
|
||||
rightNub = rightNub + right;
|
||||
rightAllNub = rightAllNub + list.size();
|
||||
}
|
||||
double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率
|
||||
double rightPoint = (double) rightNub / (double) rightAllNub;//子节点正确率
|
||||
if (rightPoint <= rightFather) {
|
||||
isRemove = true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user