mirror of
https://gitee.com/dromara/easyAi.git
synced 2024-11-30 10:47:49 +08:00
添加语言分类器
This commit is contained in:
parent
c94423ace7
commit
4b77887601
90
src/main/java/org/wlld/naturalLanguage/LangBody.java
Normal file
90
src/main/java/org/wlld/naturalLanguage/LangBody.java
Normal file
@ -0,0 +1,90 @@
|
||||
package org.wlld.naturalLanguage;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @description
|
||||
* @date 6:43 下午 2020/2/23
|
||||
*/
|
||||
public class LangBody {
|
||||
private int a1;
|
||||
private int a2;
|
||||
private int a3;
|
||||
private int a4;
|
||||
private int a5;
|
||||
private int a6;
|
||||
private int a7;
|
||||
private int a8;
|
||||
private int key;
|
||||
|
||||
public int getA1() {
|
||||
return a1;
|
||||
}
|
||||
|
||||
public void setA1(int a1) {
|
||||
this.a1 = a1;
|
||||
}
|
||||
|
||||
public int getA2() {
|
||||
return a2;
|
||||
}
|
||||
|
||||
public void setA2(int a2) {
|
||||
this.a2 = a2;
|
||||
}
|
||||
|
||||
public int getA3() {
|
||||
return a3;
|
||||
}
|
||||
|
||||
public void setA3(int a3) {
|
||||
this.a3 = a3;
|
||||
}
|
||||
|
||||
public int getA4() {
|
||||
return a4;
|
||||
}
|
||||
|
||||
public void setA4(int a4) {
|
||||
this.a4 = a4;
|
||||
}
|
||||
|
||||
public int getA5() {
|
||||
return a5;
|
||||
}
|
||||
|
||||
public void setA5(int a5) {
|
||||
this.a5 = a5;
|
||||
}
|
||||
|
||||
public int getA6() {
|
||||
return a6;
|
||||
}
|
||||
|
||||
public void setA6(int a6) {
|
||||
this.a6 = a6;
|
||||
}
|
||||
|
||||
public int getA7() {
|
||||
return a7;
|
||||
}
|
||||
|
||||
public void setA7(int a7) {
|
||||
this.a7 = a7;
|
||||
}
|
||||
|
||||
public int getA8() {
|
||||
return a8;
|
||||
}
|
||||
|
||||
public void setA8(int a8) {
|
||||
this.a8 = a8;
|
||||
}
|
||||
|
||||
public int getKey() {
|
||||
return key;
|
||||
}
|
||||
|
||||
public void setKey(int key) {
|
||||
this.key = key;
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
package org.wlld.naturalLanguage;
|
||||
|
||||
import org.omg.Messaging.SYNC_WITH_TRANSPORT;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@ -14,8 +13,13 @@ public class Sentence {
|
||||
private Word firstWord;
|
||||
private List<Word> waitWords = new ArrayList<>();//词
|
||||
private List<String> keyWords;//分词结果下标按照时间序列排序
|
||||
private List<Integer> features = new ArrayList<>();//时序特征
|
||||
private int key;
|
||||
|
||||
public List<Integer> getFeatures() {
|
||||
return features;
|
||||
}
|
||||
|
||||
public List<String> getKeyWords() {
|
||||
return keyWords;
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package org.wlld.naturalLanguage;
|
||||
|
||||
|
||||
import org.wlld.randomForest.RandomForest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@ -11,8 +13,10 @@ import java.util.List;
|
||||
*/
|
||||
public class Talk {
|
||||
private List<WorldBody> allWorld = WordTemple.get().getAllWorld();//所有词集合
|
||||
private RandomForest randomForest = WordTemple.get().getRandomForest();//获取随机森林模型
|
||||
private List<List<String>> wordTimes = WordTemple.get().getWordTimes();
|
||||
|
||||
public void talk(String sentence) {
|
||||
public void talk(String sentence) throws Exception {
|
||||
String rgm = null;
|
||||
if (sentence.indexOf(",") > -1) {
|
||||
rgm = ",";
|
||||
@ -33,11 +37,50 @@ public class Talk {
|
||||
sentences.add(sentenceWords);
|
||||
}
|
||||
restructure(sentences);
|
||||
for (Sentence sentence1 : sentences) {
|
||||
System.out.println(sentence1.getKeyWords());
|
||||
//进行识别
|
||||
if (randomForest != null) {
|
||||
for (Sentence sentence1 : sentences) {
|
||||
List<Integer> features = sentence1.getFeatures();
|
||||
List<String> keyWords = sentence1.getKeyWords();
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int nub = 0;
|
||||
List<String> words = wordTimes.get(i);
|
||||
String word = keyWords.get(i);
|
||||
if (word != null) {
|
||||
nub = getNub(words, word);
|
||||
}
|
||||
features.add(nub);
|
||||
}
|
||||
LangBody langBody = new LangBody();
|
||||
langBody.setA1(features.get(0));
|
||||
langBody.setA2(features.get(1));
|
||||
langBody.setA3(features.get(2));
|
||||
langBody.setA4(features.get(3));
|
||||
langBody.setA5(features.get(4));
|
||||
langBody.setA6(features.get(5));
|
||||
langBody.setA7(features.get(6));
|
||||
langBody.setA8(features.get(7));
|
||||
int type = randomForest.forest(langBody);
|
||||
System.out.println("type==" + type);
|
||||
}
|
||||
} else {
|
||||
System.out.println("随机森林没有训练");
|
||||
}
|
||||
}
|
||||
|
||||
private int getNub(List<String> words, String testWord) {
|
||||
int nub = 0;
|
||||
int size = words.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
String word = words.get(i);
|
||||
if (testWord.hashCode() == word.hashCode() && testWord.equals(word)) {
|
||||
nub = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return nub;
|
||||
}
|
||||
|
||||
private void catchSentence(String sentence, Sentence sentenceWords) {//把句子拆开
|
||||
int len = sentence.length();
|
||||
for (int i = 0; i < len; i++) {
|
||||
|
@ -65,7 +65,7 @@ public class TemplateReader {//模板读取类
|
||||
word();
|
||||
}
|
||||
|
||||
public void word() {
|
||||
public void word() throws Exception {
|
||||
Tokenizer tokenizer = new Tokenizer();
|
||||
tokenizer.start(model);
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
package org.wlld.naturalLanguage;
|
||||
|
||||
import org.wlld.randomForest.DataTable;
|
||||
import org.wlld.randomForest.RandomForest;
|
||||
import org.wlld.tools.ArithUtil;
|
||||
import org.wlld.tools.Frequency;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
@ -15,9 +15,10 @@ import java.util.Map;
|
||||
public class Tokenizer extends Frequency {
|
||||
private List<Sentence> sentences = WordTemple.get().getSentences();//所有断句
|
||||
private List<WorldBody> allWorld = WordTemple.get().getAllWorld();//所有词集合
|
||||
private List<List<String>> wordTimes = WordTemple.get().getWordTimes();//所有词编号
|
||||
private Word nowWord;//上一次出现的关键字
|
||||
|
||||
public void start(Map<Integer, List<String>> model) {
|
||||
public void start(Map<Integer, List<String>> model) throws Exception {
|
||||
//model的主键是类别,值是该类别语句的集合
|
||||
for (Map.Entry<Integer, List<String>> mod : model.entrySet()) {
|
||||
if (mod.getKey() != 0) {
|
||||
@ -35,15 +36,83 @@ public class Tokenizer extends Frequency {
|
||||
}
|
||||
restructure();//对集合中的词进行词频统计
|
||||
//这里分词已经结束,对词进行编号
|
||||
//test();
|
||||
number();
|
||||
//进入随机森林进行学习
|
||||
study();
|
||||
}
|
||||
|
||||
private void test() {//分词测试类
|
||||
private void number() {//分词编号
|
||||
for (Sentence sentence : sentences) {
|
||||
System.out.println(sentence.getKeyWords());
|
||||
List<Integer> features = sentence.getFeatures();
|
||||
List<String> sentenceList = sentence.getKeyWords();
|
||||
int size = sentenceList.size();//时间序列的深度
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (!wordTimes.contains(i)) {
|
||||
wordTimes.add(new ArrayList<>());
|
||||
}
|
||||
List<String> list = wordTimes.get(i);
|
||||
int nub = list.size();
|
||||
features.add(nub);
|
||||
list.add(sentenceList.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void study() throws Exception {
|
||||
Set<String> column = new HashSet<>();
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int t = i + 1;
|
||||
column.add("a" + t);
|
||||
}
|
||||
column.add("key");
|
||||
DataTable dataTable = new DataTable(column);
|
||||
dataTable.setKey("key");
|
||||
//初始化随机森林
|
||||
RandomForest randomForest = new RandomForest(5);
|
||||
WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版
|
||||
randomForest.init(dataTable);
|
||||
for (Sentence sentence : sentences) {
|
||||
LangBody langBody = new LangBody();
|
||||
List<Integer> features = sentence.getFeatures();
|
||||
langBody.setKey(sentence.getKey());
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int nub = 0;
|
||||
if (features.contains(i)) {
|
||||
nub = features.get(i);
|
||||
}
|
||||
int t = i + 1;
|
||||
switch (t) {
|
||||
case 1:
|
||||
langBody.setA1(nub);
|
||||
break;
|
||||
case 2:
|
||||
langBody.setA2(nub);
|
||||
break;
|
||||
case 3:
|
||||
langBody.setA3(nub);
|
||||
break;
|
||||
case 4:
|
||||
langBody.setA4(nub);
|
||||
break;
|
||||
case 5:
|
||||
langBody.setA5(nub);
|
||||
break;
|
||||
case 6:
|
||||
langBody.setA6(nub);
|
||||
break;
|
||||
case 7:
|
||||
langBody.setA7(nub);
|
||||
break;
|
||||
case 8:
|
||||
langBody.setA8(nub);
|
||||
break;
|
||||
}
|
||||
}
|
||||
randomForest.insert(langBody);
|
||||
}
|
||||
randomForest.study();
|
||||
}
|
||||
|
||||
private void restructure() {//对句子里面的Word进行词频统计
|
||||
for (Sentence words : sentences) {
|
||||
List<WorldBody> listWord = allWorld;
|
||||
|
@ -1,5 +1,7 @@
|
||||
package org.wlld.naturalLanguage;
|
||||
|
||||
import org.wlld.randomForest.RandomForest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@ -12,10 +14,28 @@ public class WordTemple {
|
||||
private static WordTemple Word_Temple = new WordTemple();
|
||||
private List<Sentence> sentences = new ArrayList<>();//所有断句
|
||||
private List<WorldBody> allWorld = new ArrayList<>();//所有词集合
|
||||
private List<List<String>> wordTimes = new ArrayList<>();//词编号
|
||||
private RandomForest randomForest;//保存的随机森林模型
|
||||
|
||||
public RandomForest getRandomForest() {
|
||||
return randomForest;
|
||||
}
|
||||
|
||||
public void setRandomForest(RandomForest randomForest) {
|
||||
this.randomForest = randomForest;
|
||||
}
|
||||
|
||||
private WordTemple() {
|
||||
}
|
||||
|
||||
public List<List<String>> getWordTimes() {
|
||||
return wordTimes;
|
||||
}
|
||||
|
||||
public void setWordTimes(List<List<String>> wordTimes) {
|
||||
this.wordTimes = wordTimes;
|
||||
}
|
||||
|
||||
public static WordTemple get() {
|
||||
return Word_Temple;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user