to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+ "\u001b[32m[2022-06-22 21:31:15,580] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "import paddle.nn as nn\n",
+ "\n",
+ "class SeqClsModel(nn.Layer):\n",
+ " def __init__(self, model_checkpoint, num_labels):\n",
+ " super(SeqClsModel, self).__init__()\n",
+ " self.model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " model_checkpoint,\n",
+ " num_classes=num_labels,\n",
+ " )\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask, token_type_ids):\n",
+ " logits = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
+ " return logits\n",
+ "\n",
+ " def train_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+ " logits = self(input_ids, attention_mask, token_type_ids)\n",
+ " loss = nn.CrossEntropyLoss()(logits, label)\n",
+ " return {\"loss\": loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+ " logits = self(input_ids, attention_mask, token_type_ids)\n",
+ " return {'pred': logits, 'target': label}\n",
+ "\n",
+ "model = SeqClsModel(model_checkpoint, num_labels=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2 设置参数并使用 Trainer 开始训练\n",
+ "\n",
+ "现在我们可以着手使用 ``FastNLP.Trainer`` 进行训练了。\n",
+ "\n",
+ "首先,为了高效地训练 ``ERNIE`` 模型,我们最好为学习率指定一定的策略。``paddlenlp`` 提供的 ``LinearDecayWithWarmup`` 可以令学习率在一段时间内从 0 开始线性地增长(预热),然后再线性地衰减至 0 。在本篇教程中,我们将学习率设置为 ``5e-5``,预热时间为 ``0.1``,然后将得到的的 ``lr_scheduler`` 赋值给 ``AdamW`` 优化器。\n",
+ "\n",
+ "其次,我们还可以为 ``Trainer`` 指定多个 ``Callback`` 来在基础的训练过程之外进行额外的定制操作。在本篇教程中,我们使用的 ``Callback`` 有以下三种:\n",
+ "\n",
+ "- ``RichCallback`` - 在训练和验证时显示进度条,以便观察训练的过程\n",
+ "- ``LRSchedCallback`` - 由于我们使用了 ``Scheduler``,因此需要将 ``lr_scheduler`` 传给该 ``Callback`` 以在训练中进行更新\n",
+ "- ``LoadBestModelCallback`` - 该 ``Callback`` 会评估结果中的 ``'acc#accuracy'`` 值,保存训练中出现的正确率最高的模型,并在训练结束时加载到模型上,方便对模型进行测试和评估。\n",
+ "\n",
+ "在 ``Trainer`` 中,我们还可以设置 ``metrics`` 来衡量模型的表现。``Accuracy`` 能够根据传入的预测值和真实值计算出模型预测的正确率。还记得模型中 ``evaluate_step`` 函数的返回值吗?键 ``pred`` 和 ``target`` 分别为 ``Accuracy.update`` 的参数名,在验证过程中 ``FastNLP`` 会自动将键和参数名匹配从而计算出正确率,这也是我们规定模型需要返回字典类型数据的原因。\n",
+ "\n",
+ "``Accuracy`` 的返回值包含三个部分:``acc``、``total`` 和 ``correct``,分别代表 ``正确率``、 ``数据总数`` 和 ``预测正确的数目``,这让您能够直观地知晓训练中模型的变化,``LoadBestModelCallback`` 的参数 ``'acc#accuracy'`` 也正是代表了 ``accuracy`` 指标的 ``acc`` 结果。\n",
+ "\n",
+ "在设定好参数之后,调用 ``run`` 函数便可以进行训练和验证了。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "[21:31:16] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.895833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1075.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.8975,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1077.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.911667,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1094.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9225,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1107.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9275,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1113.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.930833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1117.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.935833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1123.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.935833,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1123.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.9375,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1125.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "{\n",
+ " \"acc#accuracy\": 0.941667,\n",
+ " \"total#accuracy\": 1200.0,\n",
+ " \"correct#accuracy\": 1130.0\n",
+ "}\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m{\u001b[0m\n",
+ " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n",
+ " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n",
+ " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n",
+ "\u001b[1m}\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[21:34:28] INFO Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n",
+ " 6-22-21_29_12_898095/best_so_far with \n",
+ " acc#accuracy: 0.941667... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n",
+ "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "[21:34:34] INFO Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n",
+ " 98095/best_so_far... \n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n",
+ "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from fastNLP import LRSchedCallback, RichCallback, LoadBestModelCallback\n",
+ "from fastNLP import Trainer, Accuracy\n",
+ "from paddlenlp.transformers import LinearDecayWithWarmup\n",
+ "\n",
+ "n_epochs = 2\n",
+ "num_training_steps = len(train_dataloader) * n_epochs\n",
+ "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n",
+ "optimizer = paddle.optimizer.AdamW(\n",
+ " learning_rate=lr_scheduler,\n",
+ " parameters=model.parameters(),\n",
+ ")\n",
+ "callbacks = [\n",
+ " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n",
+ " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n",
+ " RichCallback()\n",
+ "]\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " driver=\"paddle\",\n",
+ " optimizers=optimizer,\n",
+ " device=0,\n",
+ " n_epochs=n_epochs,\n",
+ " train_dataloader=train_dataloader,\n",
+ " evaluate_dataloaders=val_dataloader,\n",
+ " evaluate_every=60,\n",
+ " metrics={\"accuracy\": Accuracy()},\n",
+ " callbacks=callbacks,\n",
+ ")\n",
+ "trainer.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.3 测试和评估\n",
+ "\n",
+ "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``FastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+ "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+ "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+ "集算什么??简直是画蛇添足!!']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+ "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+ "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+ "集算什么??简直是画蛇添足!!']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+ "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+ "]\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+ "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+ "]\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 0\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['交通方便;环境很好;服务态度很好 房间较小']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['交通方便;环境很好;服务态度很好 房间较小']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 1\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 1\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+ "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+ "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+ "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+ "。']\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+ "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+ "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+ "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+ "。']\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "labels: 1\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "labels: 1\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from fastNLP import Evaluator\n",
+ "def test_batch_step_fn(evaluator, batch):\n",
+ " input_ids = batch[\"input_ids\"]\n",
+ " attention_mask = batch[\"attention_mask\"]\n",
+ " token_type_ids = batch[\"token_type_ids\"]\n",
+ " logits = model(input_ids, attention_mask, token_type_ids)\n",
+ " predict = logits.argmax().item()\n",
+ " print(\"text:\", batch['text'])\n",
+ " print(\"labels:\", predict)\n",
+ "\n",
+ "evaluator = Evaluator(\n",
+ " model=model,\n",
+ " dataloaders=test_dataloader,\n",
+ " driver=\"paddle\",\n",
+ " device=0,\n",
+ " evaluate_batch_step_fn=test_batch_step_fn,\n",
+ ")\n",
+ "evaluator.run(5) "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 ('fnlp-paddle')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/tutorials/figures/paddle-ernie-1.0-masking-levels.png
new file mode 100644
index 00000000..ff2519c4
Binary files /dev/null and b/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ
diff --git a/tutorials/figures/paddle-ernie-1.0-masking.png b/tutorials/figures/paddle-ernie-1.0-masking.png
new file mode 100644
index 00000000..ed003a2f
Binary files /dev/null and b/tutorials/figures/paddle-ernie-1.0-masking.png differ
diff --git a/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png
new file mode 100644
index 00000000..d45f65d8
Binary files /dev/null and b/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ
diff --git a/tutorials/figures/paddle-ernie-3.0-framework.png b/tutorials/figures/paddle-ernie-3.0-framework.png
new file mode 100644
index 00000000..f50ddb1c
Binary files /dev/null and b/tutorials/figures/paddle-ernie-3.0-framework.png differ