diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index c49d27f..a55e704 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -1,6 +1,5 @@ package org.wlld.randomForest; -import org.wlld.tools.ArithUtil; import java.lang.reflect.Method; import java.util.*; @@ -50,7 +49,7 @@ public class Tree {//决策树 } private double log2(double p) { - return ArithUtil.div(Math.log(p), Math.log(2)); + return Math.log(p) / Math.log(2); } private double getEnt(List list) { @@ -67,14 +66,14 @@ public class Tree {//决策树 double ent = 0; //求信息熵 for (Map.Entry entry1 : myType.entrySet()) { - double g = ArithUtil.div(entry1.getValue(), list.size());//每个类别的概率 - ent = ArithUtil.add(ent, ArithUtil.mul(g, log2(g))); + double g = (double) entry1.getValue() / (double) list.size();//每个类别的概率 + ent = ent + g * log2(g); } return -ent; } private double getGain(double ent, double dNub, double gain) { - return ArithUtil.add(gain, ArithUtil.mul(ent, dNub)); + return gain + ent * dNub; } private List createNode(Node node) { @@ -122,19 +121,19 @@ public class Tree {//决策树 sonNode.typeId = entry.getKey();//该属性值 int myNub = list.size();//该属性值下数据的数量 double ent = getEnt(list);//该属性值 的信息熵 - double dNub = ArithUtil.div(myNub, fatherNub);//该属性值在 父级样本中出现的概率 - IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV); + double dNub = (double) myNub / (double) fatherNub;//该属性值在 父级样本中出现的概率 + IV = dNub * log2(dNub) + IV; gain = getGain(ent, dNub, gain); } Gain gain1 = new Gain(); gainMap.put(name, gain1); - gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益 + gain1.gain = fatherEnt - gain;//信息增益 if (IV != 0) { - gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率 + gain1.gainRatio = gain1.gain / -IV;//增益率 } else { gain1.gainRatio = 1000000; } - sigmaG = ArithUtil.add(gain1.gain, sigmaG); + sigmaG = gain1.gain + sigmaG; i++; } double avgGain = sigmaG / i; @@ -143,9 +142,9 @@ public class Tree {//决策树 System.out.println("平均信息增益==============================" + avgGain); for (Map.Entry entry : gainMap.entrySet()) { Gain gain = entry.getValue(); - System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size() - + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio); - if (gainMap.size() == 1 || (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gainRatio == -2))) { +// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size() +// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio); + if (gainMap.size() == 1 || ((gain.gain >= avgGain || Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) { gainRatio = gain.gainRatio; key = entry.getKey(); } @@ -222,7 +221,7 @@ public class Tree {//决策树 private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚 //System.out.println("惩罚"); double trust = treeWithTrust.getTrust();//获取当前信任值 - trust = ArithUtil.mul(trust, trustPunishment); + trust = trust * trustPunishment; treeWithTrust.setTrust(trust); } @@ -312,7 +311,7 @@ public class Tree {//决策树 int fatherType = getType(father.fatherList); int nub = getRightPoint(father.fatherList, fatherType); //父级该样本正确率 - double rightFather = ArithUtil.div(nub, father.fatherList.size()); + double rightFather = (double) nub / (double) father.fatherList.size(); int rightNub = 0; int rightAllNub = 0; for (int i = 0; i < sonNodes.size(); i++) { @@ -322,7 +321,7 @@ public class Tree {//决策树 rightNub = rightNub + right; rightAllNub = rightAllNub + list.size(); } - double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率 + double rightPoint = (double) rightNub / (double) rightAllNub;//子节点正确率 if (rightPoint <= rightFather) { isRemove = true; }