mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 12:47:35 +08:00
eca0c6f8c4
* pass all tests * prepare CWS & POS API * update tutorials * add README.md in tutorials/ & api/
249 lines
6.2 KiB
Plaintext
249 lines
6.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"source": [
|
|
"# fastNLP 1分钟上手教程"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## step 1\n",
|
|
"读取数据集"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\users\\zyfeng\\miniconda3\\envs\\fastnlp\\lib\\site-packages\\tqdm\\autonotebook\\__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
|
|
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.path.append(\"../\")\n",
|
|
"\n",
|
|
"from fastNLP import DataSet\n",
|
|
"\n",
|
|
"data_path = \"./sample_data/tutorial_sample_dataset.csv\"\n",
|
|
"ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\\t')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n",
|
|
"'label': 4 type=str}"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds[1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## step 2\n",
|
|
"数据预处理\n",
|
|
"1. 类型转换\n",
|
|
"2. 切分验证集\n",
|
|
"3. 构建词典"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# 将所有字母转为小写\n",
|
|
"ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
|
|
"# label转int\n",
|
|
"ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
|
|
"\n",
|
|
"def split_sent(ins):\n",
|
|
" return ins['raw_sentence'].split()\n",
|
|
"ds.apply(split_sent, new_field_name='words', is_input=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Train size: 54\n",
|
|
"Test size: 23\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 分割训练集/验证集\n",
|
|
"train_data, dev_data = ds.split(0.3)\n",
|
|
"print(\"Train size: \", len(train_data))\n",
|
|
"print(\"Test size: \", len(dev_data))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from fastNLP import Vocabulary\n",
|
|
"vocab = Vocabulary(min_freq=2)\n",
|
|
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
|
|
"\n",
|
|
"# 将句子中的每个词转换为词典索引: Vocabulary.to_index(word)\n",
|
|
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
|
|
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## step 3\n",
|
|
"定义模型"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from fastNLP.models import CNNText\n",
|
|
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## step 4\n",
|
|
"开始训练"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input fields after batch(if batch size is 2):\n",
|
|
"\twords: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,) \n",
|
|
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 25]) \n",
|
|
"target fields after batch(if batch size is 2):\n",
|
|
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
|
"\n",
|
|
"training epochs started 2019-01-12 17-00-48\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "23979df0f63e446fbb0406b919b91dd3",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Evaluation at Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.173913\n",
|
|
"Evaluation at Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.26087\n",
|
|
"Evaluation at Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.304348\n",
|
|
"\n",
|
|
"In Epoch:3/Step:6, got best dev performance:AccuracyMetric: acc=0.304348\n",
|
|
"Reloaded the best model.\n",
|
|
"Train finished!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
|
|
"trainer = Trainer(model=model, \n",
|
|
" train_data=train_data, \n",
|
|
" dev_data=dev_data,\n",
|
|
" loss=CrossEntropyLoss(),\n",
|
|
" metrics=AccuracyMetric()\n",
|
|
" )\n",
|
|
"trainer.train()\n",
|
|
"print('Train finished!')\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 本教程结束。更多操作请参考进阶教程。"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|