fastNLP/docs/transfer.ipynb
2022-05-07 15:48:58 +00:00

102 lines
4.5 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'应当为一个字符串,其值应当为以下之一:``[None, \"dist\", \"unrepeatdist\"]``;为 ``None`` 时,表示不需要考虑当前 ``dataloader`` 切换为分布式状态;为 ``\"dist\"`` 时,表示该 ``dataloader`` 应该保证每个 ``gpu`` 上返回的 ``batch`` 的数量是一样多的,允许出现少量 ``sample`` ,在 不同 ``gpu`` 上出现重复;为 ``\"unrepeatdist\"`` 时,表示该 ``dataloader`` 应该保证所有 ``gpu`` 上迭代出来的数据合并起来应该刚好等于原始的 数据,允许不同 ``gpu`` 上 ``batch`` 的数量不一致。其中 ``trainer`` 中 ``kwargs`` 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``\"dist\"`` 否则为 ``None`` ``evaluator`` 中的 ``kwargs`` 的参数 ``use_dist_sampler`` 为 ``True`` 时,该值为 ``\"unrepeatdist\"``,否则为 ``None`` 注意当 ``dist`` 为 ``ReproducibleSampler, ReproducibleBatchSampler`` 时,是断点重训加载时 ``driver.load`` 函数在调用; 当 ``dist`` 为 ``str`` 或者 ``None`` 时,是 ``trainer`` 在初始化时调用该函数;'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import re\n",
"import sys\n",
"sys.path.append(\"../\")\n",
"# import fastNLP\n",
"\n",
"def get_class(text):\n",
" return f\":class:`~{text}`\"\n",
"\n",
"def get_meth(text):\n",
" return f\":meth:`~{text}`\"\n",
"\n",
"def get_module(text):\n",
" return f\":mod:`~{text}`\"\n",
"\n",
"def replace(matched):\n",
" \"\"\"\n",
" \"\"\"\n",
" text = matched.group()\n",
" non_space = text.strip()\n",
" if non_space == \"\":\n",
" return text\n",
" # 如果原本就添加了 `,那么只加一个\n",
" if non_space.startswith(\"`\"):\n",
" res = \"`\" + non_space\n",
" else:\n",
" res = \"``\" + non_space\n",
" if non_space.endswith(\"`\"):\n",
" res += \"`\"\n",
" else:\n",
" res += \"``\"\n",
" return text.replace(non_space, f\"{res}\")\n",
"\n",
"def transfer(text):\n",
" \"\"\"\n",
" 将输入的 ``text`` 中的英文单词添加 \"``\"。在得到结果后最好手动检查一下,\n",
" \"\"\"\n",
" res = re.sub(\n",
" # 匹配字母、下划线、点、逗号、引号、中括号和`\n",
" pattern=r\"[a-zA-Z_ \\.,\\\"\\'\\[\\]`]+\",\n",
" repl=replace,\n",
" string=text\n",
" )\n",
" return res\n",
"\n",
"\n",
"text = '应当为一个字符串,其值应当为以下之一:[None, \"dist\", \"unrepeatdist\"];为 None 时,表示不需要考虑当前 dataloader \\\n",
" 切换为分布式状态;为 \"dist\" 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 \\\n",
" 不同 gpu 上出现重复;为 \"unrepeatdist\" 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 \\\n",
" 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 \"dist\" \\\n",
" 否则为 None evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 \"unrepeatdist\",否则为 None \\\n",
" 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; \\\n",
" 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;'\n",
"transfer(text)"
]
}
],
"metadata": {
"interpreter": {
"hash": "c79c3370938623706c2d55a7989cf7c7c31ff0346157477d22565bb370580b77"
},
"kernelspec": {
"display_name": "Python 3.7.13 ('fnlp')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}