diff --git a/src/main/java/org/wlld/naturalLanguage/Talk.java b/src/main/java/org/wlld/naturalLanguage/Talk.java index 6e90289..1dfa92f 100644 --- a/src/main/java/org/wlld/naturalLanguage/Talk.java +++ b/src/main/java/org/wlld/naturalLanguage/Talk.java @@ -13,9 +13,17 @@ import java.util.List; * @date 4:14 下午 2020/2/23 */ public class Talk { - private List allWorld = WordTemple.get().getAllWorld();//所有词集合 - private RandomForest randomForest = WordTemple.get().getRandomForest();//获取随机森林模型 - private List> wordTimes = WordTemple.get().getWordTimes(); + private List allWorld;//所有词集合 + private RandomForest randomForest;//获取随机森林模型 + private List> wordTimes; + private WordTemple wordTemple; + + public Talk(WordTemple wordTemple) { + this.wordTemple = wordTemple; + allWorld = wordTemple.getAllWorld();//所有词集合 + randomForest = wordTemple.getRandomForest();//获取随机森林模型 + wordTimes = wordTemple.getWordTimes(); + } public List talk(String sentence) throws Exception { List typeList = new ArrayList<>(); @@ -69,7 +77,7 @@ public class Talk { features.add(nub); } int type = 0; - if (ArithUtil.div(wrong, wordNumber) < WordTemple.get().getGarbageTh()) { + if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) { LangBody langBody = new LangBody(); langBody.setA1(features.get(0)); langBody.setA2(features.get(1)); @@ -135,7 +143,7 @@ public class Talk { listWord = body.getWorldBodies();//这个body报了一次空指针 word.setWordFrequency(body.getWordFrequency()); } - Tokenizer tokenizer = new Tokenizer(); + Tokenizer tokenizer = new Tokenizer(wordTemple); tokenizer.radiation(words); } diff --git a/src/main/java/org/wlld/naturalLanguage/TemplateReader.java b/src/main/java/org/wlld/naturalLanguage/TemplateReader.java index cbf2c9e..cb8555a 100644 --- a/src/main/java/org/wlld/naturalLanguage/TemplateReader.java +++ b/src/main/java/org/wlld/naturalLanguage/TemplateReader.java @@ -20,7 +20,7 @@ public class TemplateReader {//模板读取类 * @param charsetName 文字编码(一般使用UTF-8) * @throws Exception 找不到文字抛出异常 */ - public void read(String url, String charsetName) throws Exception { + public void read(String url, String charsetName, WordTemple wordTemple) throws Exception { this.charsetName = charsetName; File file = new File(url); InputStream is = new FileInputStream(file); @@ -68,11 +68,12 @@ public class TemplateReader {//模板读取类 } } } - word(); + word(wordTemple); } - public void word() throws Exception { - Tokenizer tokenizer = new Tokenizer(); + public void word(WordTemple wordTemple) throws Exception { + //将模版注入分词器进行分词 + Tokenizer tokenizer = new Tokenizer(wordTemple); tokenizer.start(model); } diff --git a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java index d73fb22..2f4a8d5 100644 --- a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java +++ b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java @@ -13,10 +13,18 @@ import java.util.*; * @date 7:42 上午 2020/2/23 */ public class Tokenizer extends Frequency { - private List sentences = WordTemple.get().getSentences();//所有断句 - private List allWorld = WordTemple.get().getAllWorld();//所有词集合 - private List> wordTimes = WordTemple.get().getWordTimes();//所有词编号 + private List sentences;//所有断句 + private List allWorld;//所有词集合 + private List> wordTimes;//所有词编号 private Word nowWord;//上一次出现的关键字 + private WordTemple wordTemple; + + public Tokenizer(WordTemple wordTemple) { + this.wordTemple = wordTemple; + sentences = wordTemple.getSentences();//所有断句 + allWorld = wordTemple.getAllWorld();//所有词集合 + wordTimes = wordTemple.getWordTimes();//所有词编号 + } public void start(Map> model) throws Exception { //model的主键是类别,值是该类别语句的集合 @@ -59,7 +67,7 @@ public class Tokenizer extends Frequency { } private void number() {//分词编号 - System.out.println("开始编码:" + sentences.size()); + System.out.println("开始编码:" + (sentences.size() + 1)); for (Sentence sentence : sentences) { List features = sentence.getFeatures(); List sentenceList = sentence.getKeyWords(); @@ -85,10 +93,12 @@ public class Tokenizer extends Frequency { } column.add("key"); DataTable dataTable = new DataTable(column); - dataTable.setKey("key"); + dataTable.setKey("key");//确认结果集主键 //初始化随机森林 - RandomForest randomForest = new RandomForest(11); - WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版 + RandomForest randomForest = new RandomForest(wordTemple.getTreeNub()); + randomForest.setTrustTh(wordTemple.getTrustTh()); + randomForest.setTrustPunishment(wordTemple.getTrustPunishment()); + wordTemple.setRandomForest(randomForest);//保存随机森林到模版 randomForest.init(dataTable); for (Sentence sentence : sentences) { LangBody langBody = new LangBody(); diff --git a/src/main/java/org/wlld/naturalLanguage/WordTemple.java b/src/main/java/org/wlld/naturalLanguage/WordTemple.java index d254e64..402249e 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordTemple.java +++ b/src/main/java/org/wlld/naturalLanguage/WordTemple.java @@ -11,33 +11,30 @@ import java.util.List; * @date 4:15 下午 2020/2/23 */ public class WordTemple { - private static WordTemple Word_Temple = new WordTemple(); private List sentences = new ArrayList<>();//所有断句 private List allWorld = new ArrayList<>();//所有词集合 private List> wordTimes = new ArrayList<>();//词编号 private RandomForest randomForest;//保存的随机森林模型 + //四大参数 private double garbageTh = 0.5;//垃圾分类的阈值默认0.5 private double trustPunishment = 0.1;//信任惩罚 + private double trustTh = 0.1;//信任阈值,相当于一次信任惩罚的数值 + private int treeNub = 11;//丛林里面树的数量 - public WordModel getModel() {//获取模型 - WordModel wordModel = new WordModel(); - wordModel.setAllWorld(allWorld); - wordModel.setWordTimes(wordTimes); - wordModel.setGarbageTh(garbageTh); - wordModel.setTrustPunishment(trustPunishment); - wordModel.setTrustTh(randomForest.getTrustTh()); - wordModel.setRfModel(randomForest.getModel()); - return wordModel; + public int getTreeNub() { + return treeNub; } - public void insertModel(WordModel wordModel) throws Exception {//注入模型 - allWorld = wordModel.getAllWorld(); - wordTimes = wordModel.getWordTimes(); - garbageTh = wordModel.getGarbageTh(); - trustPunishment = wordModel.getTrustPunishment(); - randomForest = new RandomForest(); - randomForest.setTrustTh(wordModel.getTrustTh()); - randomForest.insertModel(wordModel.getRfModel()); + public void setTreeNub(int treeNub) { + this.treeNub = treeNub; + } + + public double getTrustTh() { + return trustTh; + } + + public void setTrustTh(double trustTh) { + this.trustTh = trustTh; } public double getTrustPunishment() { @@ -64,9 +61,6 @@ public class WordTemple { this.randomForest = randomForest; } - private WordTemple() { - } - public List> getWordTimes() { return wordTimes; } @@ -75,10 +69,6 @@ public class WordTemple { this.wordTimes = wordTimes; } - public static WordTemple get() { - return Word_Temple; - } - public List getSentences() { return sentences; } diff --git a/src/main/java/org/wlld/randomForest/RandomForest.java b/src/main/java/org/wlld/randomForest/RandomForest.java index c961400..5e2b89d 100644 --- a/src/main/java/org/wlld/randomForest/RandomForest.java +++ b/src/main/java/org/wlld/randomForest/RandomForest.java @@ -13,6 +13,15 @@ public class RandomForest { private Random random = new Random(); private Tree[] forest; private double trustTh = 0.1;//信任阈值 + private double trustPunishment = 0.1;//信任惩罚 + + public double getTrustPunishment() { + return trustPunishment; + } + + public void setTrustPunishment(double trustPunishment) { + this.trustPunishment = trustPunishment; + } public double getTrustTh() { return trustTh; @@ -33,21 +42,6 @@ public class RandomForest { } } - public void insertModel(RfModel rfModel) throws Exception {//注入模型 - if (rfModel != null) { - Map nodeMap = rfModel.getNodeMap(); - forest = new Tree[nodeMap.size()]; - for (Map.Entry entry : nodeMap.entrySet()) { - int key = entry.getKey(); - Tree tree = new Tree(); - forest[key] = tree; - tree.setRootNode(entry.getValue()); - } - } else { - throw new Exception("model is null"); - } - } - public RfModel getModel() {//获取模型 RfModel rfModel = new RfModel(); Map nodeMap = new HashMap<>(); @@ -89,6 +83,7 @@ public class RandomForest { return type; } + //rf初始化 public void init(DataTable dataTable) throws Exception { //一棵树属性的数量 if (dataTable.getSize() > 4) { @@ -96,7 +91,7 @@ public class RandomForest { //int kNub = dataTable.getSize() - 1; // System.out.println("knNub==" + kNub); for (int i = 0; i < forest.length; i++) { - Tree tree = new Tree(getRandomData(dataTable, kNub)); + Tree tree = new Tree(getRandomData(dataTable, kNub), trustPunishment); forest[i] = tree; } } else { @@ -119,6 +114,7 @@ public class RandomForest { } } + //从总属性列表中随机挑选属性kNub个属性数量 private DataTable getRandomData(DataTable dataTable, int kNub) throws Exception { Set attr = dataTable.getKeyType(); Set myName = new HashSet<>(); diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index 7ab565a..039df2b 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -18,6 +18,7 @@ public class Tree {//决策树 private List endList;//最终结果分类 private List lastNodes = new ArrayList<>();//最后一层节点集合 private Random random = new Random(); + private double trustPunishment;//信任惩罚 public Node getRootNode() { return rootNode; @@ -36,11 +37,13 @@ public class Tree {//决策树 private double gainRatio;//信息增益率 } - public Tree() { + public Tree(double trustPunishment) { + this.trustPunishment = trustPunishment; } - public Tree(DataTable dataTable) throws Exception { + public Tree(DataTable dataTable, double trustPunishment) throws Exception { if (dataTable != null && dataTable.getKey() != null) { + this.trustPunishment = trustPunishment; this.dataTable = dataTable; } else { throw new Exception("dataTable is empty"); @@ -221,7 +224,7 @@ public class Tree {//决策树 private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚 //System.out.println("惩罚"); double trust = treeWithTrust.getTrust();//获取当前信任值 - trust = ArithUtil.mul(trust, WordTemple.get().getTrustPunishment()); + trust = ArithUtil.mul(trust, trustPunishment); treeWithTrust.setTrust(trust); } diff --git a/src/test/java/org/wlld/LangTest.java b/src/test/java/org/wlld/LangTest.java index b504325..8f34caa 100644 --- a/src/test/java/org/wlld/LangTest.java +++ b/src/test/java/org/wlld/LangTest.java @@ -70,17 +70,15 @@ public class LangTest { public static void test() throws Exception { //创建模板读取累 TemplateReader templateReader = new TemplateReader(); + WordTemple wordTemple = new WordTemple();//初始化语言模版 + wordTemple.setTreeNub(9); //读取语言模版,第一个参数是模版地址,第二个参数是编码方式 (教程里的第三个参数已经省略) //同时也是学习过程 - templateReader.read("D:\\b/a.txt", "UTF-8"); - //学习结束获取模型参数 - // WordModel wordModel = WordTemple.get().getModel(); - //不用学习注入模型参数 - // WordTemple.get().insertModel(wordModel); - Talk talk = new Talk(); + templateReader.read("/Users/lidapeng/Desktop/myDocument/model.txt", "UTF-8", wordTemple); + Talk talk = new Talk(wordTemple); //输入语句进行识别,若有标点符号会形成LIST中的每个元素 //返回的集合中每个值代表了输入语句,每个标点符号前语句的分类 - List list = talk.talk("有个快递尽快代我邮寄出去"); + List list = talk.talk("空调坏了"); System.out.println(list); } }