mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-12-02 03:38:08 +08:00
自然语言处理将WordTemple独立出来,并取消单例
This commit is contained in:
parent
09b68cad98
commit
be9d063fb4
@ -13,9 +13,17 @@ import java.util.List;
|
||||
* @date 4:14 下午 2020/2/23
|
||||
*/
|
||||
public class Talk {
|
||||
private List<WorldBody> allWorld = WordTemple.get().getAllWorld();//所有词集合
|
||||
private RandomForest randomForest = WordTemple.get().getRandomForest();//获取随机森林模型
|
||||
private List<List<String>> wordTimes = WordTemple.get().getWordTimes();
|
||||
private List<WorldBody> allWorld;//所有词集合
|
||||
private RandomForest randomForest;//获取随机森林模型
|
||||
private List<List<String>> wordTimes;
|
||||
private WordTemple wordTemple;
|
||||
|
||||
public Talk(WordTemple wordTemple) {
|
||||
this.wordTemple = wordTemple;
|
||||
allWorld = wordTemple.getAllWorld();//所有词集合
|
||||
randomForest = wordTemple.getRandomForest();//获取随机森林模型
|
||||
wordTimes = wordTemple.getWordTimes();
|
||||
}
|
||||
|
||||
public List<Integer> talk(String sentence) throws Exception {
|
||||
List<Integer> typeList = new ArrayList<>();
|
||||
@ -69,7 +77,7 @@ public class Talk {
|
||||
features.add(nub);
|
||||
}
|
||||
int type = 0;
|
||||
if (ArithUtil.div(wrong, wordNumber) < WordTemple.get().getGarbageTh()) {
|
||||
if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) {
|
||||
LangBody langBody = new LangBody();
|
||||
langBody.setA1(features.get(0));
|
||||
langBody.setA2(features.get(1));
|
||||
@ -135,7 +143,7 @@ public class Talk {
|
||||
listWord = body.getWorldBodies();//这个body报了一次空指针
|
||||
word.setWordFrequency(body.getWordFrequency());
|
||||
}
|
||||
Tokenizer tokenizer = new Tokenizer();
|
||||
Tokenizer tokenizer = new Tokenizer(wordTemple);
|
||||
tokenizer.radiation(words);
|
||||
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ public class TemplateReader {//模板读取类
|
||||
* @param charsetName 文字编码(一般使用UTF-8)
|
||||
* @throws Exception 找不到文字抛出异常
|
||||
*/
|
||||
public void read(String url, String charsetName) throws Exception {
|
||||
public void read(String url, String charsetName, WordTemple wordTemple) throws Exception {
|
||||
this.charsetName = charsetName;
|
||||
File file = new File(url);
|
||||
InputStream is = new FileInputStream(file);
|
||||
@ -68,11 +68,12 @@ public class TemplateReader {//模板读取类
|
||||
}
|
||||
}
|
||||
}
|
||||
word();
|
||||
word(wordTemple);
|
||||
}
|
||||
|
||||
public void word() throws Exception {
|
||||
Tokenizer tokenizer = new Tokenizer();
|
||||
public void word(WordTemple wordTemple) throws Exception {
|
||||
//将模版注入分词器进行分词
|
||||
Tokenizer tokenizer = new Tokenizer(wordTemple);
|
||||
tokenizer.start(model);
|
||||
}
|
||||
|
||||
|
@ -13,10 +13,18 @@ import java.util.*;
|
||||
* @date 7:42 上午 2020/2/23
|
||||
*/
|
||||
public class Tokenizer extends Frequency {
|
||||
private List<Sentence> sentences = WordTemple.get().getSentences();//所有断句
|
||||
private List<WorldBody> allWorld = WordTemple.get().getAllWorld();//所有词集合
|
||||
private List<List<String>> wordTimes = WordTemple.get().getWordTimes();//所有词编号
|
||||
private List<Sentence> sentences;//所有断句
|
||||
private List<WorldBody> allWorld;//所有词集合
|
||||
private List<List<String>> wordTimes;//所有词编号
|
||||
private Word nowWord;//上一次出现的关键字
|
||||
private WordTemple wordTemple;
|
||||
|
||||
public Tokenizer(WordTemple wordTemple) {
|
||||
this.wordTemple = wordTemple;
|
||||
sentences = wordTemple.getSentences();//所有断句
|
||||
allWorld = wordTemple.getAllWorld();//所有词集合
|
||||
wordTimes = wordTemple.getWordTimes();//所有词编号
|
||||
}
|
||||
|
||||
public void start(Map<Integer, List<String>> model) throws Exception {
|
||||
//model的主键是类别,值是该类别语句的集合
|
||||
@ -59,7 +67,7 @@ public class Tokenizer extends Frequency {
|
||||
}
|
||||
|
||||
private void number() {//分词编号
|
||||
System.out.println("开始编码:" + sentences.size());
|
||||
System.out.println("开始编码:" + (sentences.size() + 1));
|
||||
for (Sentence sentence : sentences) {
|
||||
List<Integer> features = sentence.getFeatures();
|
||||
List<String> sentenceList = sentence.getKeyWords();
|
||||
@ -85,10 +93,12 @@ public class Tokenizer extends Frequency {
|
||||
}
|
||||
column.add("key");
|
||||
DataTable dataTable = new DataTable(column);
|
||||
dataTable.setKey("key");
|
||||
dataTable.setKey("key");//确认结果集主键
|
||||
//初始化随机森林
|
||||
RandomForest randomForest = new RandomForest(11);
|
||||
WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版
|
||||
RandomForest randomForest = new RandomForest(wordTemple.getTreeNub());
|
||||
randomForest.setTrustTh(wordTemple.getTrustTh());
|
||||
randomForest.setTrustPunishment(wordTemple.getTrustPunishment());
|
||||
wordTemple.setRandomForest(randomForest);//保存随机森林到模版
|
||||
randomForest.init(dataTable);
|
||||
for (Sentence sentence : sentences) {
|
||||
LangBody langBody = new LangBody();
|
||||
|
@ -11,33 +11,30 @@ import java.util.List;
|
||||
* @date 4:15 下午 2020/2/23
|
||||
*/
|
||||
public class WordTemple {
|
||||
private static WordTemple Word_Temple = new WordTemple();
|
||||
private List<Sentence> sentences = new ArrayList<>();//所有断句
|
||||
private List<WorldBody> allWorld = new ArrayList<>();//所有词集合
|
||||
private List<List<String>> wordTimes = new ArrayList<>();//词编号
|
||||
private RandomForest randomForest;//保存的随机森林模型
|
||||
//四大参数
|
||||
private double garbageTh = 0.5;//垃圾分类的阈值默认0.5
|
||||
private double trustPunishment = 0.1;//信任惩罚
|
||||
private double trustTh = 0.1;//信任阈值,相当于一次信任惩罚的数值
|
||||
private int treeNub = 11;//丛林里面树的数量
|
||||
|
||||
public WordModel getModel() {//获取模型
|
||||
WordModel wordModel = new WordModel();
|
||||
wordModel.setAllWorld(allWorld);
|
||||
wordModel.setWordTimes(wordTimes);
|
||||
wordModel.setGarbageTh(garbageTh);
|
||||
wordModel.setTrustPunishment(trustPunishment);
|
||||
wordModel.setTrustTh(randomForest.getTrustTh());
|
||||
wordModel.setRfModel(randomForest.getModel());
|
||||
return wordModel;
|
||||
public int getTreeNub() {
|
||||
return treeNub;
|
||||
}
|
||||
|
||||
public void insertModel(WordModel wordModel) throws Exception {//注入模型
|
||||
allWorld = wordModel.getAllWorld();
|
||||
wordTimes = wordModel.getWordTimes();
|
||||
garbageTh = wordModel.getGarbageTh();
|
||||
trustPunishment = wordModel.getTrustPunishment();
|
||||
randomForest = new RandomForest();
|
||||
randomForest.setTrustTh(wordModel.getTrustTh());
|
||||
randomForest.insertModel(wordModel.getRfModel());
|
||||
public void setTreeNub(int treeNub) {
|
||||
this.treeNub = treeNub;
|
||||
}
|
||||
|
||||
public double getTrustTh() {
|
||||
return trustTh;
|
||||
}
|
||||
|
||||
public void setTrustTh(double trustTh) {
|
||||
this.trustTh = trustTh;
|
||||
}
|
||||
|
||||
public double getTrustPunishment() {
|
||||
@ -64,9 +61,6 @@ public class WordTemple {
|
||||
this.randomForest = randomForest;
|
||||
}
|
||||
|
||||
private WordTemple() {
|
||||
}
|
||||
|
||||
public List<List<String>> getWordTimes() {
|
||||
return wordTimes;
|
||||
}
|
||||
@ -75,10 +69,6 @@ public class WordTemple {
|
||||
this.wordTimes = wordTimes;
|
||||
}
|
||||
|
||||
public static WordTemple get() {
|
||||
return Word_Temple;
|
||||
}
|
||||
|
||||
public List<Sentence> getSentences() {
|
||||
return sentences;
|
||||
}
|
||||
|
@ -13,6 +13,15 @@ public class RandomForest {
|
||||
private Random random = new Random();
|
||||
private Tree[] forest;
|
||||
private double trustTh = 0.1;//信任阈值
|
||||
private double trustPunishment = 0.1;//信任惩罚
|
||||
|
||||
public double getTrustPunishment() {
|
||||
return trustPunishment;
|
||||
}
|
||||
|
||||
public void setTrustPunishment(double trustPunishment) {
|
||||
this.trustPunishment = trustPunishment;
|
||||
}
|
||||
|
||||
public double getTrustTh() {
|
||||
return trustTh;
|
||||
@ -33,21 +42,6 @@ public class RandomForest {
|
||||
}
|
||||
}
|
||||
|
||||
public void insertModel(RfModel rfModel) throws Exception {//注入模型
|
||||
if (rfModel != null) {
|
||||
Map<Integer, Node> nodeMap = rfModel.getNodeMap();
|
||||
forest = new Tree[nodeMap.size()];
|
||||
for (Map.Entry<Integer, Node> entry : nodeMap.entrySet()) {
|
||||
int key = entry.getKey();
|
||||
Tree tree = new Tree();
|
||||
forest[key] = tree;
|
||||
tree.setRootNode(entry.getValue());
|
||||
}
|
||||
} else {
|
||||
throw new Exception("model is null");
|
||||
}
|
||||
}
|
||||
|
||||
public RfModel getModel() {//获取模型
|
||||
RfModel rfModel = new RfModel();
|
||||
Map<Integer, Node> nodeMap = new HashMap<>();
|
||||
@ -89,6 +83,7 @@ public class RandomForest {
|
||||
return type;
|
||||
}
|
||||
|
||||
//rf初始化
|
||||
public void init(DataTable dataTable) throws Exception {
|
||||
//一棵树属性的数量
|
||||
if (dataTable.getSize() > 4) {
|
||||
@ -96,7 +91,7 @@ public class RandomForest {
|
||||
//int kNub = dataTable.getSize() - 1;
|
||||
// System.out.println("knNub==" + kNub);
|
||||
for (int i = 0; i < forest.length; i++) {
|
||||
Tree tree = new Tree(getRandomData(dataTable, kNub));
|
||||
Tree tree = new Tree(getRandomData(dataTable, kNub), trustPunishment);
|
||||
forest[i] = tree;
|
||||
}
|
||||
} else {
|
||||
@ -119,6 +114,7 @@ public class RandomForest {
|
||||
}
|
||||
}
|
||||
|
||||
//从总属性列表中随机挑选属性kNub个属性数量
|
||||
private DataTable getRandomData(DataTable dataTable, int kNub) throws Exception {
|
||||
Set<String> attr = dataTable.getKeyType();
|
||||
Set<String> myName = new HashSet<>();
|
||||
|
@ -18,6 +18,7 @@ public class Tree {//决策树
|
||||
private List<Integer> endList;//最终结果分类
|
||||
private List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
|
||||
private Random random = new Random();
|
||||
private double trustPunishment;//信任惩罚
|
||||
|
||||
public Node getRootNode() {
|
||||
return rootNode;
|
||||
@ -36,11 +37,13 @@ public class Tree {//决策树
|
||||
private double gainRatio;//信息增益率
|
||||
}
|
||||
|
||||
public Tree() {
|
||||
public Tree(double trustPunishment) {
|
||||
this.trustPunishment = trustPunishment;
|
||||
}
|
||||
|
||||
public Tree(DataTable dataTable) throws Exception {
|
||||
public Tree(DataTable dataTable, double trustPunishment) throws Exception {
|
||||
if (dataTable != null && dataTable.getKey() != null) {
|
||||
this.trustPunishment = trustPunishment;
|
||||
this.dataTable = dataTable;
|
||||
} else {
|
||||
throw new Exception("dataTable is empty");
|
||||
@ -221,7 +224,7 @@ public class Tree {//决策树
|
||||
private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
|
||||
//System.out.println("惩罚");
|
||||
double trust = treeWithTrust.getTrust();//获取当前信任值
|
||||
trust = ArithUtil.mul(trust, WordTemple.get().getTrustPunishment());
|
||||
trust = ArithUtil.mul(trust, trustPunishment);
|
||||
treeWithTrust.setTrust(trust);
|
||||
}
|
||||
|
||||
|
@ -70,17 +70,15 @@ public class LangTest {
|
||||
public static void test() throws Exception {
|
||||
//创建模板读取累
|
||||
TemplateReader templateReader = new TemplateReader();
|
||||
WordTemple wordTemple = new WordTemple();//初始化语言模版
|
||||
wordTemple.setTreeNub(9);
|
||||
//读取语言模版,第一个参数是模版地址,第二个参数是编码方式 (教程里的第三个参数已经省略)
|
||||
//同时也是学习过程
|
||||
templateReader.read("D:\\b/a.txt", "UTF-8");
|
||||
//学习结束获取模型参数
|
||||
// WordModel wordModel = WordTemple.get().getModel();
|
||||
//不用学习注入模型参数
|
||||
// WordTemple.get().insertModel(wordModel);
|
||||
Talk talk = new Talk();
|
||||
templateReader.read("/Users/lidapeng/Desktop/myDocument/model.txt", "UTF-8", wordTemple);
|
||||
Talk talk = new Talk(wordTemple);
|
||||
//输入语句进行识别,若有标点符号会形成LIST中的每个元素
|
||||
//返回的集合中每个值代表了输入语句,每个标点符号前语句的分类
|
||||
List<Integer> list = talk.talk("有个快递尽快代我邮寄出去");
|
||||
List<Integer> list = talk.talk("空调坏了");
|
||||
System.out.println(list);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user