Reduce the overhead of repeatedly creating worker processes in Batch

yh_cc 2019-01-18 23:33:19 +08:00
parent 2e3ef52a7d
commit d9ac334409
3 changed files with 75 additions and 15 deletions

View File

@@ -15,24 +15,28 @@ from fastNLP.core.sampler import RandomSampler
class Batch(object):
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
timeout=0.0):
timeout=0.0, keep_process=False):
"""
Batch is an iterable object which iterates over mini-batches.
Example::
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
# ...
iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler())
for epoch in range(num_epochs):
for batch_x, batch_y in iterator: # each epoch, the sampler is used again to generate the indices.
# ...
:param DataSet dataset: a DataSet object
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object
:param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors.
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
:param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. data is produced in the main thread. This feature is experimental, use with caution;
if the DataSet is large but each batch takes little time to prepare, multiprocessing may not bring a speedup.
:param pin_memory: bool, defaults to False. When set to True, it may reduce the blocking time of moving tensors from cpu to gpu.
:param timeout: float, a number greater than 0, only effective when num_workers>0. An error is raised if no batch is obtained within
this time; this can be used to detect whether batch production is blocked.
:param keep_process: bool, defaults to False. Only effective with multiprocessing. Under multiprocessing, repeatedly creating the batch iterator keeps
spawning and destroying worker processes, which may affect speed. When keep_process is True, the worker processes are not closed
until the Batch object is deleted; with keep_process=True, use del on the Batch object to delete it and close the processes.
"""
if num_workers < 0:
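For orientation, a minimal usage sketch of the new parameter (hypothetical; data_set, SequentialSampler and num_epochs are assumed from the example above, not part of this diff):

batch = Batch(data_set, batch_size=16, sampler=SequentialSampler(),
              num_workers=2, keep_process=True)
for epoch in range(num_epochs):      # workers persist across epochs
    for batch_x, batch_y in batch:   # reused, not respawned each epoch
        pass                         # training step goes here
del batch                            # deleting the Batch shuts the workers down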
@@ -45,15 +49,24 @@ class Batch(object):
self.batch_size = batch_size
self.sampler = sampler
self.num_workers = num_workers
self.keep_process = keep_process
self.pin_memory = pin_memory
self.timeout = timeout
self.as_numpy = as_numpy
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
self.cur_batch_indices = None
self._data_iterator = None
def __iter__(self):
# TODO under multiprocessing, every loop currently re-creates the worker processes, which may be costly. Consider reusing the iterator directly.
return _DataLoaderIter(self)
if self._data_iterator is not None:
# reset the index_list for a new pass
self._data_iterator.reset()
return self._data_iterator
elif self.keep_process and self.num_workers > 0:
self._data_iterator = _DataLoaderIter(self)
return self._data_iterator
else:  # the common case
return _DataLoaderIter(self)
def __len__(self):
return self.num_batches
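The three-way branch above means the first __iter__ call with keep_process=True and num_workers>0 caches a _DataLoaderIter, and every later call returns the same object after reset(). An illustrative check of that caching behavior (a sketch, assuming data_set exists):

b = Batch(data_set, batch_size=16, num_workers=2, keep_process=True)
it_first = iter(b)    # second branch: creates and caches a _DataLoaderIter
it_second = iter(b)   # first branch: reset() on the cached iterator
assert it_first is it_second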
@@ -61,6 +74,12 @@ class Batch(object):
def get_batch_indices(self):
return self.cur_batch_indices
def __del__(self):
if self.keep_process is True:
del self._data_iterator
def to_tensor(batch, dtype):
try:
if dtype in (int, np.int8, np.int16, np.int32, np.int64):
@@ -276,6 +295,7 @@ class _DataLoaderIter(object):
self.num_workers = batcher.num_workers
self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
self.timeout = batcher.timeout
self.keep_process = batcher.keep_process
self.done_event = threading.Event()
self.curidx = 0
self.idx_list = self.sampler(self.dataset)
@@ -335,6 +355,17 @@ class _DataLoaderIter(object):
for _ in range(2 * self.num_workers):
self._put_indices()
def reset(self):
"""
Reset curidx and re-sample the idx_list. Only useful when keep_process is in effect.
:return:
"""
if self.keep_process:
self.curidx = 0
self.idx_list = self.sampler(self.dataset)
for _ in range(2 * self.num_workers):
self._put_indices()
def _get_batch(self):
if self.timeout > 0:
try:
@@ -366,7 +397,8 @@ class _DataLoaderIter(object):
# stop once there are no outstanding batches
if self.batches_outstanding == 0:
self._shutdown_workers()
if not self.keep_process:
self._shutdown_workers()
raise StopIteration
while True:
@@ -449,4 +481,4 @@ class _DataLoaderIter(object):
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
self._shutdown_workers()
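Since _DataLoaderIter.__del__ now calls _shutdown_workers() unconditionally, teardown under keep_process is driven entirely by object deletion. A hedged sketch of the chain (names taken from this diff):

batch = Batch(dataset, batch_size=32, num_workers=2, keep_process=True)
for batch_x, batch_y in batch:   # first pass spawns the worker processes
    pass
del batch   # Batch.__del__ -> del self._data_iterator
            #   -> _DataLoaderIter.__del__ -> _shutdown_workers()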

View File

@@ -61,7 +61,8 @@ class Trainer(object):
:param BaseSampler sampler: method used to generate batch data.
:param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. data is produced in the main thread. This feature is experimental, use with caution;
if the DataSet is large but each batch takes little time to prepare, multiprocessing may not bring a speedup.
:param pin_memory: bool, defaults to False. When set to True, it may reduce the blocking time of moving tensors from cpu to gpu.
:param pin_memory: bool, defaults to False. When set to True, page-locked (pinned) memory is used, which may increase memory consumption; if memory is plentiful,
consider enabling it for speed. When pin_memory is True, data is moved from cpu to gpu with non_blocking=True by default.
:param timeout: float, a number greater than 0, only effective when num_workers>0. An error is raised if no batch is obtained within
this time; this can be used to detect whether batch production is blocked.
:param bool use_tqdm: whether to use tqdm to show train progress.
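For intuition on the pin_memory/non_blocking pairing described above, a minimal standalone PyTorch sketch (not fastNLP code): a tensor in page-locked host memory can be copied to the GPU asynchronously, overlapping the transfer with computation.

import torch

x = torch.randn(32, 128).pin_memory()    # page-locked (pinned) host tensor
if torch.cuda.is_available():
    y = x.to('cuda', non_blocking=True)  # asynchronous host-to-device copy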
@@ -246,7 +247,8 @@ class Trainer(object):
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout)
num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
keep_process=True)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
@@ -255,7 +257,8 @@ class Trainer(object):
indices = data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.before_batch(batch_x, batch_y, indices)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
non_blocking=self.pin_memory) # when pin_memory is set, use a non_blocking transfer.
prediction = self._data_forward(self.model, batch_x)
# edit prediction
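As a rough stand-in for what _move_dict_value_to_device presumably does with the new flag (a hypothetical helper, not fastNLP's actual implementation):

import torch

def move_dict_to_device(d, device, non_blocking=False):
    # Move every tensor value of the dict to `device`; with pinned inputs
    # and non_blocking=True the copies can overlap with GPU computation.
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            d[key] = value.to(device, non_blocking=non_blocking)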

View File

@@ -237,6 +237,31 @@ class TrainerTestGround(unittest.TestCase):
use_tqdm=False,
print_every=2)
def test_case2(self):
# check wrong metrics
data_set = prepare_fake_dataset2('x1', 'x2')
def test_trainer_multiprocess(self):
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)
def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'loss': loss}
model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=True,
print_every=2,
num_workers=2,
pin_memory=False,
timeout=0,
)
trainer.train()