修改部分注释

This commit is contained in:
yh 2019-04-29 15:12:50 +08:00
parent c077107555
commit 16388d5698
3 changed files with 7 additions and 5 deletions

View File

@ -355,7 +355,7 @@ class Trainer(object):
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1则每个epoch结束验证一次仅在传入dev_data时有
:param str,None save_path: 将模型保存路径如果为None则不保存模型如果dev_data为None则保存最后一次迭代的模
保存的时候不仅保存了参数还保存了模型结构即便使用DataParallel这里也只保存模型
保存的时候不仅保存了参数还保存了模型结构即便使用了nn.DataParallel这里也只保存模型
:param prefetch: bool, 是否使用额外的进程对产生batch数据理论上会使得Batch迭代更快
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False则将loss打印在终端中
:param str,int,torch.device,list(int) device: 将模型load到哪个设备默认为None即Trainer不对模型
@ -366,7 +366,7 @@ class Trainer(object):
2. torch.device将模型装载到torch.device上
3. int: 将使用device_id为值的gpu进行训练
3. int: 将使用该gpu进行训练
4. list(int)如果多于1个device将使用torch.nn.DataParallel包裹model, 并使用传入的device

View File

@ -200,13 +200,13 @@ def _move_model_to_device(model, device):
else:
if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
raise ValueError("There is no usable gpu. set `device` as `cpu`.")
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
if isinstance(device, int):
assert device>-1, "device can only be positive integer"
assert device>-1, "device can only be non-negative integer"
assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
device)
device = torch.device('cuda:{}'.format(device))
@ -227,7 +227,7 @@ def _move_model_to_device(model, device):
assert list(types)[0] == int, "Only int supported for multiple devices."
assert len(set(device))==len(device), "Duplicated device id found in device."
for d in device:
assert d>-1, "Only positive device id allowed."
assert d>-1, "Only non-negative device id allowed."
if len(device)>1:
output_device = device[0]
model = nn.DataParallel(model, device_ids=device, output_device=output_device)

View File

@ -33,6 +33,8 @@ class TestMoveModelDeivce(unittest.TestCase):
assert model.param.device == torch.device('cuda:0')
with self.assertRaises(Exception):
_move_model_to_device(model, 'cuda:1000')
# 测试None
model = _move_model_to_device(model, None)
def test_case2(self):
# 测试使用int初始化