Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-05 05:38:31 +08:00)

commit 16388d5698, parent c077107555

    Update some comments (修改部分注释)
@@ -355,7 +355,7 @@ class Trainer(object):
     :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有
         效。
     :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模
-        型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
+        型。保存的时候不仅保存了参数,还保存了模型结构。即便使用了nn.DataParallel,这里也只保存模型。
     :param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。
     :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
     :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
@@ -366,7 +366,7 @@ class Trainer(object):

         2. torch.device:将模型装载到torch.device上。

-        3. int: 将使用device_id为该值的gpu进行训练
+        3. int: 将使用该gpu进行训练

         4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。
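In English, the docstring lines above say: save_path stores the model structure as well as its parameters, and even when nn.DataParallel is used only the underlying model is saved; validate_every controls how many steps pass between validations (-1 means once per epoch); prefetch spawns an extra process to produce batches; use_tqdm toggles the progress bar; and device accepts a str, a torch.device, an int gpu id, or a list of ints (multiple ids are handled through torch.nn.DataParallel). A minimal sketch of those device forms, assuming _move_model_to_device is importable from fastNLP.core.utils (the exact module path is not shown in this diff):

import torch
import torch.nn as nn
from fastNLP.core.utils import _move_model_to_device  # assumed import path

model = nn.Linear(4, 2)
model = _move_model_to_device(model, None)                 # None: placement is left to the caller
model = _move_model_to_device(model, 'cpu')                # str, e.g. 'cpu' or 'cuda:0'
model = _move_model_to_device(model, torch.device('cpu'))  # torch.device
if torch.cuda.is_available():
    model = _move_model_to_device(model, 0)                # int: train on that gpu id
if torch.cuda.device_count() > 1:
    model = _move_model_to_device(model, [0, 1])           # list(int): wrapped in nn.DataParallel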
@@ -200,13 +200,13 @@ def _move_model_to_device(model, device):
     else:
         if not torch.cuda.is_available() and (
                 device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
-            raise ValueError("There is no usable gpu. set `device` as `cpu`.")
+            raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")

         if isinstance(model, torch.nn.DataParallel):
             raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")

         if isinstance(device, int):
-            assert device>-1, "device can only be positive integer"
+            assert device>-1, "device can only be non-negative integer"
             assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
                                                                                                       device)
             device = torch.device('cuda:{}'.format(device))
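Both message tweaks live in the single-device branch: without CUDA, anything other than 'cpu' (or None) is rejected; a model already wrapped in torch.nn.DataParallel only accepts device=None; and an int must be a non-negative id smaller than torch.cuda.device_count() before it is turned into torch.device('cuda:<id>'). A standalone sketch of that int path (the helper name here is mine, not fastNLP's):

import torch

def _normalize_int_device(device_id):
    # Mirrors the checks in the hunk above: non-negative id, within the visible gpu count.
    assert device_id > -1, "device can only be non-negative integer"
    assert torch.cuda.device_count() > device_id, \
        "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(), device_id)
    return torch.device('cuda:{}'.format(device_id))

if torch.cuda.is_available():
    print(_normalize_int_device(0))  # -> device(type='cuda', index=0)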
@@ -227,7 +227,7 @@ def _move_model_to_device(model, device):
         assert list(types)[0] == int, "Only int supported for multiple devices."
         assert len(set(device))==len(device), "Duplicated device id found in device."
         for d in device:
-            assert d>-1, "Only positive device id allowed."
+            assert d>-1, "Only non-negative device id allowed."
         if len(device)>1:
             output_device = device[0]
             model = nn.DataParallel(model, device_ids=device, output_device=output_device)
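For a list of ids, each element gets the same non-negative check, duplicates are rejected, and with more than one id the model is wrapped in nn.DataParallel with the first id as output_device. A rough standalone equivalent (the function name is mine; the single-id case, which this hunk does not show, is left out):

import torch
import torch.nn as nn

def _wrap_multi_gpu(model, device_ids):
    # Mirrors the list(int) branch above; assumes CUDA is available.
    assert all(isinstance(d, int) for d in device_ids), "Only int supported for multiple devices."
    assert len(set(device_ids)) == len(device_ids), "Duplicated device id found in device."
    for d in device_ids:
        assert d > -1, "Only non-negative device id allowed."
    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids, output_device=device_ids[0])
    return model

if torch.cuda.device_count() > 1:
    wrapped = _wrap_multi_gpu(nn.Linear(4, 2), [0, 1])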
@@ -33,6 +33,8 @@ class TestMoveModelDeivce(unittest.TestCase):
         assert model.param.device == torch.device('cuda:0')
         with self.assertRaises(Exception):
             _move_model_to_device(model, 'cuda:1000')
+        # 测试None
+        model = _move_model_to_device(model, None)

     def test_case2(self):
         # 测试使用int初始化
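The two added lines extend the test: after checking that an out-of-range 'cuda:1000' raises, the case now also passes device=None (# 测试None means "test None"; the next case's # 测试使用int初始化 means "test initialization with an int"). If, as the else: at the top of the earlier hunk suggests, device=None hands the model back untouched, a CPU-only variant of that check could look like this (import path assumed, as before):

import unittest
import torch
import torch.nn as nn
from fastNLP.core.utils import _move_model_to_device  # assumed import path

class TestMoveModelToDeviceNone(unittest.TestCase):
    def test_none_leaves_model_in_place(self):
        model = nn.Linear(4, 2)
        returned = _move_model_to_device(model, None)
        # device=None should not relocate any parameter.
        self.assertTrue(all(p.device == torch.device('cpu') for p in returned.parameters()))

if __name__ == '__main__':
    unittest.main()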