mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
update tutorial-3456 lxr 220603
This commit is contained in:
parent
82b06767f5
commit
3797f91434
@ -801,24 +801,24 @@
|
||||
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
|
||||
"\n",
|
||||
"# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n",
|
||||
"datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv'))\n",
|
||||
"datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
|
||||
"train_ds, test_ds = datasets.split(ratio=0.7)\n",
|
||||
"data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n",
|
||||
"\n",
|
||||
"# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n",
|
||||
"encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
|
||||
" return_attention_mask=True)\n",
|
||||
"data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n",
|
||||
"data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
|
||||
"\n",
|
||||
"# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n",
|
||||
"target_vocab = Vocabulary(padding=None, unknown=None)\n",
|
||||
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
|
||||
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
|
||||
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
|
||||
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
|
||||
" new_field_name='target')\n",
|
||||
"\n",
|
||||
"# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n",
|
||||
"data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
|
||||
"data_bundle.set_ignore('label', 'text') \n",
|
||||
"data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence') \n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
|
@ -9,9 +9,9 @@
|
||||
"\n",
|
||||
"  1   fastNLP 中的 dataloader\n",
|
||||
" \n",
|
||||
"    1.1   dataloader 的职责描述\n",
|
||||
"    1.1   dataloader 的基本介绍\n",
|
||||
"\n",
|
||||
"    1.2   dataloader 的基本使用\n",
|
||||
"    1.2   dataloader 的函数创建\n",
|
||||
"\n",
|
||||
"  2   fastNLP 中 dataloader 的延伸\n",
|
||||
"\n",
|
||||
@ -27,32 +27,143 @@
|
||||
"source": [
|
||||
"## 1. fastNLP 中的 dataloader\n",
|
||||
"\n",
|
||||
"### 1.1 dataloader 的职责描述\n",
|
||||
"### 1.1 dataloader 的基本介绍\n",
|
||||
"\n",
|
||||
"在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前"
|
||||
"在`fastNLP 0.8`的开发中,最关键的开发目标就是**实现`fastNLP`对当前主流机器学习框架**,例如\n",
|
||||
"\n",
|
||||
"  **较为火热的`pytorch`**,以及**国产的`paddle`和`jittor`的兼容**,扩大受众的同时,也是助力国产\n",
|
||||
"\n",
|
||||
"本着分而治之的思想,我们可以将`fastNLP 0.8`对`pytorch`、`paddle`、`jittor`框架的兼容,划分为\n",
|
||||
"\n",
|
||||
"    **对数据预处理**、**批量`batch`的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n",
|
||||
"\n",
|
||||
"  针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n",
|
||||
"\n",
|
||||
"    而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n",
|
||||
"\n",
|
||||
"    因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n",
|
||||
"\n",
|
||||
"只有涉及到张量、模型,不同框架才展现出其各自的特色:**`pytorch`中的`tensor`和`nn.Module`**\n",
|
||||
"\n",
|
||||
"    **在`paddle`中称为`tensor`和`nn.Layer`**,**在`jittor`中则称为`Var`和`Module`**\n",
|
||||
"\n",
|
||||
"    因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n",
|
||||
"\n",
|
||||
"  针对批量`batch`的处理,作为`fastNLP 0.8`中框架无关部分想框架相关部分的过渡\n",
|
||||
"\n",
|
||||
"    就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n",
|
||||
"\n",
|
||||
"**`dataloader`模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n",
|
||||
"\n",
|
||||
"    第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n",
|
||||
"\n",
|
||||
"    第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n",
|
||||
"\n",
|
||||
"    第三,**`batch`内数据格式要匹配框架**,**但`batch`结构需保持一致**,**参数匹配机制**\n",
|
||||
"\n",
|
||||
"  对此,`fastNLP 0.8`给出了 **`TorchDataLoader`、`PaddleDataLoader`和`JittorDataLoader`**\n",
|
||||
"\n",
|
||||
"    分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n",
|
||||
"\n",
|
||||
"| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n",
|
||||
"|:--|:--:|:--:|:--|:--|\n",
|
||||
"| **`dataset`** | √ | √ | 指定`dataloader`的数据内容 | |\n",
|
||||
"| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n",
|
||||
"| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n",
|
||||
"| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n",
|
||||
"| `sampler` | √ | √ | ? | 默认`None` |\n",
|
||||
"| `batch_sampler` | √ | √ | ? | 默认`None` |\n",
|
||||
"| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n",
|
||||
"| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n",
|
||||
"| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n",
|
||||
"| `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n",
|
||||
"| `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n",
|
||||
"| `prefetch_factor` | | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb8fb51c",
|
||||
"id": "60a8a224",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 1.2 dataloader 的基本使用\n",
|
||||
"  论及`dataloader`的函数,其中,`get_batch_indices`用来获取当前遍历到的`batch`序号,其他函数\n",
|
||||
"\n",
|
||||
"在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,"
|
||||
"    包括`set_ignore`、`set_pad`和`databundle`类似,请参考`tutorial-2`,此处不做更多介绍\n",
|
||||
"\n",
|
||||
"    以下是`tutorial-2`中已经介绍过的数据预处理流程,接下来是对相关数据进行`dataloader`处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "aca72b49",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Processing: 0%| | 0/4 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
|
||||
"| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask |\n",
|
||||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
|
||||
"| 5 | A comedy-dram... | positive | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
|
||||
"| 2 | This quiet , ... | positive | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
|
||||
"| 1 | A series of e... | negative | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
|
||||
"| 6 | The Importanc... | neutral | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
|
||||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"from functools import partial\n",
|
||||
"from fastNLP.transformers.torch import BertTokenizer\n",
|
||||
@ -63,69 +174,112 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"class PipeDemo:\n",
|
||||
" def __init__(self, tokenizer='bert-base-uncased', num_proc=1):\n",
|
||||
" def __init__(self, tokenizer='bert-base-uncased'):\n",
|
||||
" self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
|
||||
" self.num_proc = num_proc\n",
|
||||
"\n",
|
||||
" def process_from_file(self, path='./data/test4dataset.tsv'):\n",
|
||||
" datasets = DataSet.from_pandas(pd.read_csv(path))\n",
|
||||
" datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
|
||||
" train_ds, test_ds = datasets.split(ratio=0.7)\n",
|
||||
" train_ds, dev_ds = datasets.split(ratio=0.8)\n",
|
||||
" data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
|
||||
"\n",
|
||||
" encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
|
||||
" return_attention_mask=True)\n",
|
||||
" data_bundle.apply_field_more(encode, field_name='text', num_proc=self.num_proc)\n",
|
||||
"\n",
|
||||
" data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
|
||||
" \n",
|
||||
" target_vocab = Vocabulary(padding=None, unknown=None)\n",
|
||||
"\n",
|
||||
" target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
|
||||
" target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
|
||||
" target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
|
||||
" target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
|
||||
" new_field_name='target')\n",
|
||||
"\n",
|
||||
" data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
|
||||
" data_bundle.set_ignore('label', 'text') \n",
|
||||
" return data_bundle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de53bff4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "57a29cb9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pipe = PipeDemo(tokenizer='bert-base-uncased', num_proc=4)\n",
|
||||
" data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment') \n",
|
||||
" return data_bundle\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
|
||||
"\n",
|
||||
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "226bb081",
|
||||
"id": "76e6b8ab",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  "
|
||||
"### 1.2 dataloader 的函数创建\n",
|
||||
"\n",
|
||||
"在`fastNLP 0.8`中,**更方便、可能更常用的`dataloader`创建方法是通过`prepare_xx_dataloader`函数**\n",
|
||||
"\n",
|
||||
"  例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n",
|
||||
"\n",
|
||||
"  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7827557d",
|
||||
"execution_count": 7,
|
||||
"id": "5fd60e42",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import prepare_torch_dataloader\n",
|
||||
"\n",
|
||||
"dl_bundle = prepare_torch_dataloader(data_bundle, batch_size=arg.batch_size)"
|
||||
"train_dataset = data_bundle.get_dataset('train')\n",
|
||||
"evaluate_dataset = data_bundle.get_dataset('dev')\n",
|
||||
"\n",
|
||||
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
|
||||
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7c53f181",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" train_dataloader=train_dataloader,\n",
|
||||
" optimizers=optimizer,\n",
|
||||
"\t...\n",
|
||||
"\tdriver='torch',\n",
|
||||
"\tdevice='cuda',\n",
|
||||
"\t...\n",
|
||||
" evaluate_dataloaders=evaluate_dataloader, \n",
|
||||
" metrics={'acc': Accuracy()},\n",
|
||||
"\t...\n",
|
||||
")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9f457a6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"之所以称`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是`DataSet`类型**,**还可以**\n",
|
||||
"\n",
|
||||
"  **是`DataBundle`类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n",
|
||||
"\n",
|
||||
"  例如下方就是**直接通过`prepare_paddle_dataloader`函数生成基于`PaddleDataLoader`的字典**\n",
|
||||
"\n",
|
||||
"  在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n",
|
||||
"\n",
|
||||
"    这里也可以看出 **`evaluate_dataloaders`的妙处**,一次评测可以针对多个数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "7827557d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import prepare_paddle_dataloader\n",
|
||||
"\n",
|
||||
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -133,16 +287,14 @@
|
||||
"id": "d898cf40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  \n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" train_dataloader=dl_bundle['train'],\n",
|
||||
" optimizers=optimizer,\n",
|
||||
"\t...\n",
|
||||
"\tdriver=\"torch\",\n",
|
||||
"\tdevice='cuda',\n",
|
||||
"\tdriver='paddle',\n",
|
||||
"\tdevice='gpu',\n",
|
||||
"\t...\n",
|
||||
" evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n",
|
||||
" metrics={'acc': Accuracy()},\n",
|
||||
@ -187,6 +339,14 @@
|
||||
"print(type(dl_bundle), type(dl_bundle['train']))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f816ef5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -312,6 +312,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"\n",
|
||||
"from fastNLP import Metric\n",
|
||||
"\n",
|
||||
"class MyMetric(Metric):\n",
|
||||
@ -333,33 +336,6 @@
|
||||
" return {'prefix': acc}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "af3f8c63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2fd210c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"\n",
|
||||
"from fastNLP.models.torch import CNNText\n",
|
||||
"\n",
|
||||
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
|
||||
"\n",
|
||||
"from torch.optim import AdamW\n",
|
||||
"\n",
|
||||
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0155f447",
|
||||
@ -389,9 +365,9 @@
|
||||
"id": "e9d81760",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"接着是数据预处理,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
|
||||
"  在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
|
||||
"\n",
|
||||
"  对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
|
||||
"    对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -429,14 +405,136 @@
|
||||
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "af3f8c63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2fd210c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.models.torch import CNNText\n",
|
||||
"\n",
|
||||
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
|
||||
"\n",
|
||||
"from torch.optim import AdamW\n",
|
||||
"\n",
|
||||
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6e723b87",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. fastNLP 中 trainer 的补充介绍\n",
|
||||
"\n",
|
||||
"### 3.1 trainer 的内部结构\n",
|
||||
"\n",
|
||||
"在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n",
|
||||
"\n",
|
||||
"  很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n",
|
||||
"\n",
|
||||
"| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n",
|
||||
"|:--|:--:|:--:|:--|:--|\n",
|
||||
"| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n",
|
||||
"| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n",
|
||||
"| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n",
|
||||
"| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n",
|
||||
"| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n",
|
||||
"| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n",
|
||||
"| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n",
|
||||
"| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n",
|
||||
"| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n",
|
||||
"| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n",
|
||||
"| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n",
|
||||
"| **`train_dataloader`** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n",
|
||||
"| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n",
|
||||
"| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n",
|
||||
"| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n",
|
||||
"| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n",
|
||||
"| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n",
|
||||
"| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n",
|
||||
"| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n",
|
||||
"| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n",
|
||||
"| `callbacks` | √ | | 指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n",
|
||||
"| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n",
|
||||
"| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n",
|
||||
"| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n",
|
||||
"| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n",
|
||||
"| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n",
|
||||
"| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2fc8b9f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"  以及`trainer`模块内部的基础方法,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n",
|
||||
"\n",
|
||||
"| <div align=\"center\">名称</div> |<div align=\"center\">功能</div> | <div align=\"center\">主要参数</div> |\n",
|
||||
"|:--|:--|:--|\n",
|
||||
"| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n",
|
||||
"| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n",
|
||||
"| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n",
|
||||
"| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n",
|
||||
"| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n",
|
||||
"| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n",
|
||||
"| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n",
|
||||
"| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n",
|
||||
"| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n",
|
||||
"| `save_checkpoint` | <div style=\"line-height:25px;\">保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar`</div> | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n",
|
||||
"| `load_checkpoint` | <div style=\"line-height:25px;\">加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar`</div> | <div style=\"line-height:25px;\">`folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True`<br>`resume_training`指明是否只精确到上次训练的批量,默认`True`</div> |\n",
|
||||
"| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n",
|
||||
"| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n",
|
||||
"\n",
|
||||
"<!-- ```python\n",
|
||||
"Trainer.__init__():\n",
|
||||
"\ton_after_trainer_initialized(trainer, driver)\n",
|
||||
"Trainer.run():\n",
|
||||
"\tif num_eval_sanity_batch > 0: # 如果设置了 num_eval_sanity_batch\n",
|
||||
"\t\ton_sanity_check_begin(trainer)\n",
|
||||
"\t\ton_sanity_check_end(trainer, sanity_check_res)\n",
|
||||
"\ttry:\n",
|
||||
"\t\ton_train_begin(trainer)\n",
|
||||
"\t\twhile cur_epoch_idx < n_epochs:\n",
|
||||
"\t\t\ton_train_epoch_begin(trainer)\n",
|
||||
"\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n",
|
||||
"\t\t\t\ton_fetch_data_begin(trainer)\n",
|
||||
"\t\t\t\tbatch = next(dataloader)\n",
|
||||
"\t\t\t\ton_fetch_data_end(trainer)\n",
|
||||
"\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n",
|
||||
"\t\t\t\ton_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping 后的\n",
|
||||
"\t\t\t\ton_after_backward(trainer)\n",
|
||||
"\t\t\t\ton_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
|
||||
"\t\t\t\ton_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
|
||||
"\t\t\t\ton_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
|
||||
"\t\t\t\ton_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
|
||||
"\t\t\t\ton_train_batch_end(trainer)\n",
|
||||
"\t\t\ton_train_epoch_end(trainer)\n",
|
||||
"\texcept BaseException:\n",
|
||||
"\t\tself.on_exception(trainer, exception)\n",
|
||||
"\tfinally:\n",
|
||||
"\t\ton_train_end(trainer)\n",
|
||||
"``` -->"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1e21df35",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"然后就是初始化`trainer`实例,其中`metrics`变量输入的键值对,字串`'suffix'`和之前定义的字串`'prefix'`\n",
|
||||
"紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n",
|
||||
"\n",
|
||||
"  将拼接在一起显示到`trainer`的`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`"
|
||||
"  字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -462,51 +560,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6e723b87",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. fastNLP 中 trainer 的补充介绍\n",
|
||||
"\n",
|
||||
"### 3.1 trainer 的内部结构\n",
|
||||
"\n",
|
||||
"在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经\n",
|
||||
"\n",
|
||||
"  展示了很多关于`trainer`的使用案例,以下我们先补充介绍训练模块`trainer`的一些内部结构\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n",
|
||||
"'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n",
|
||||
"'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n",
|
||||
"'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n",
|
||||
"'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n",
|
||||
"'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n",
|
||||
"'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n",
|
||||
"'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n",
|
||||
"'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n",
|
||||
"'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n",
|
||||
"'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n",
|
||||
"'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n",
|
||||
"'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n",
|
||||
"'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n",
|
||||
"'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n",
|
||||
"'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n",
|
||||
"'trainer_state', 'zero_grad'\n",
|
||||
"\n",
|
||||
"  run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c348864c",
|
||||
"id": "b1b2e8b7",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n",
|
||||
"\n",
|
||||
"| <div align=\"center\">名称</div> | <div align=\"center\">功能</div> | <div align=\"center\">默认值</div> |\n",
|
||||
"|:--|:--|:--|\n",
|
||||
"| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n",
|
||||
"| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n",
|
||||
"| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n",
|
||||
"| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n",
|
||||
"| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@ -518,6 +588,16 @@
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.run(num_eval_batch_per_dl=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f1abfa0a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
|
@ -19,11 +19,7 @@
|
||||
"\n",
|
||||
"    2.2   使用 jittor 搭建并训练模型\n",
|
||||
"\n",
|
||||
"  3   fastNLP 实现 paddle 与 pytorch 互转\n",
|
||||
"\n",
|
||||
"    3.1   \n",
|
||||
"\n",
|
||||
"    3.2   "
|
||||
"  3   fastNLP 实现 paddle 与 pytorch 互转"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -32,7 +28,303 @@
|
||||
"id": "08752c5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"sst2data = load_dataset('glue', 'sst2')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e8cc210",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"\n",
|
||||
"from fastNLP import DataSet\n",
|
||||
"\n",
|
||||
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
|
||||
"\n",
|
||||
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
|
||||
" progress_bar=\"tqdm\")\n",
|
||||
"dataset.delete_field('sentence')\n",
|
||||
"dataset.delete_field('label')\n",
|
||||
"dataset.delete_field('idx')\n",
|
||||
"\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.from_dataset(dataset, field_name='words')\n",
|
||||
"vocab.index_dataset(dataset, field_name='words')\n",
|
||||
"\n",
|
||||
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
|
||||
"print(type(train_dataset), isinstance(train_dataset, DataSet))\n",
|
||||
"\n",
|
||||
"from fastNLP.io import DataBundle\n",
|
||||
"\n",
|
||||
"data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57a3272f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. fastNLP 结合 paddle 训练模型\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"import paddle\n",
|
||||
"\n",
|
||||
"lstm = paddle.nn.LSTM(16, 32, 2)\n",
|
||||
"\n",
|
||||
"x = paddle.randn((4, 23, 16))\n",
|
||||
"h = paddle.randn((2, 4, 32))\n",
|
||||
"c = paddle.randn((2, 4, 32))\n",
|
||||
"\n",
|
||||
"y, (h, c) = lstm(x, (h, c))\n",
|
||||
"\n",
|
||||
"print(y.shape) # [4, 23, 32]\n",
|
||||
"print(h.shape) # [2, 4, 32]\n",
|
||||
"print(c.shape) # [2, 4, 32]\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e31b3198",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import paddle\n",
|
||||
"import paddle.nn as nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ClsByPaddle(nn.Layer):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
|
||||
" nn.Layer.__init__(self)\n",
|
||||
" self.hidden_dim = hidden_dim\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
|
||||
" # self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n",
|
||||
" # num_layers=num_layers, direction='bidirectional', dropout=dropout)\n",
|
||||
" self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),\n",
|
||||
" ('activate', nn.ReLU()),\n",
|
||||
" ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))\n",
|
||||
" \n",
|
||||
" self.loss_fn = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
" def forward(self, words):\n",
|
||||
" output = self.embedding(words)\n",
|
||||
" # output, (hidden, cell) = self.lstm(output)\n",
|
||||
" hidden = paddle.randn((2, words.shape[0], self.hidden_dim))\n",
|
||||
" output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))\n",
|
||||
" return output\n",
|
||||
" \n",
|
||||
" def train_step(self, words, target):\n",
|
||||
" pred = self(words)\n",
|
||||
" return {\"loss\": self.loss_fn(pred, target)}\n",
|
||||
"\n",
|
||||
" def evaluate_step(self, words, target):\n",
|
||||
" pred = self(words)\n",
|
||||
" pred = paddle.max(pred, axis=-1)[1]\n",
|
||||
" return {\"pred\": pred, \"target\": target}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c63b030f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
|
||||
"\n",
|
||||
"model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2997c0aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from paddle.optimizer import AdamW\n",
|
||||
"\n",
|
||||
"optimizers = AdamW(parameters=model.parameters(), learning_rate=1e-2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ead35fb8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import prepare_paddle_dataloader\n",
|
||||
"\n",
|
||||
"# train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
|
||||
"# evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n",
|
||||
"\n",
|
||||
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "25e8da83",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import Trainer, Accuracy\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" driver='paddle',\n",
|
||||
" device='gpu', # 'cpu', 'gpu', 'gpu:x'\n",
|
||||
" n_epochs=10,\n",
|
||||
" optimizers=optimizers,\n",
|
||||
" train_dataloader=dl_bundle['train'], # train_dataloader,\n",
|
||||
" evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n",
|
||||
" metrics={'acc': Accuracy()}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d63c5d74",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.run(num_eval_batch_per_dl=10) # 然后卡了?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cb9a0b3c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. fastNLP 结合 jittor 训练模型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c600191d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jittor\n",
|
||||
"import jittor.nn as nn\n",
|
||||
"\n",
|
||||
"from jittor import Module\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ClsByJittor(Module):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
|
||||
" Module.__init__(self)\n",
|
||||
" self.hidden_dim = hidden_dim\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n",
|
||||
" self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n",
|
||||
" num_layers=num_layers, bidirectional=True, dropout=dropout)\n",
|
||||
" self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Linear(hidden_dim * 2, output_dim)])\n",
|
||||
"\n",
|
||||
" self.loss_fn = nn.BCELoss()\n",
|
||||
"\n",
|
||||
" def execute(self, words):\n",
|
||||
" output = self.embedding(words)\n",
|
||||
" output, (hidden, cell) = self.lstm(output)\n",
|
||||
" # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n",
|
||||
" output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n",
|
||||
" return output\n",
|
||||
" \n",
|
||||
" def train_step(self, words, target):\n",
|
||||
" pred = self(words)\n",
|
||||
" return {\"loss\": self.loss_fn(pred, target)}\n",
|
||||
"\n",
|
||||
" def evaluate_step(self, words, target):\n",
|
||||
" pred = self(words)\n",
|
||||
" pred = jittor.max(pred, axis=-1)[1]\n",
|
||||
" return {\"pred\": pred, \"target\": target}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a94ed8c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
|
||||
"\n",
|
||||
"model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6d15ebc1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jittor.optim import AdamW\n",
|
||||
"\n",
|
||||
"optimizers = AdamW(params=model.parameters(), lr=1e-2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95d8d09e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import prepare_jittor_dataloader\n",
|
||||
"\n",
|
||||
"# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
|
||||
"# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n",
|
||||
"\n",
|
||||
"dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "917eab81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import Trainer, Accuracy\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" driver='jittor',\n",
|
||||
" device='gpu', # 'cpu', 'gpu', 'cuda'\n",
|
||||
" n_epochs=10,\n",
|
||||
" optimizers=optimizers,\n",
|
||||
" train_dataloader=dl_bundle['train'], # train_dataloader,\n",
|
||||
" evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n",
|
||||
" metrics={'acc': Accuracy()}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f7c4ac5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.run(num_eval_batch_per_dl=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
Reference in New Issue
Block a user