Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-04 21:28:01 +08:00
Reduce the overhead of repeatedly creating worker processes in Batch

commit d9ac334409, parent 2e3ef52a7d
@@ -15,24 +15,28 @@ from fastNLP.core.sampler import RandomSampler

 class Batch(object):

     def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
-                 timeout=0.0):
+                 timeout=0.0, keep_process=False):
         """
         Batch is an iterable object which iterates over mini-batches.

         Example::

             for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
                 # ...
             iterator = Batch(data_set, batch_size=16, sampler=SequentialSampler())
             for epoch in range(num_epochs):
                 for batch_x, batch_y in iterator:  # each epoch the sampler regenerates the indices
                     # ...

         :param DataSet dataset: a DataSet object
         :param int batch_size: the size of the batch
         :param Sampler sampler: a Sampler object
-        :param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors.
+        :param bool as_numpy: If True, return Numpy arrays. Otherwise, return torch tensors.
         :param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. data is
             generated in the main thread. This feature is experimental; use with caution. If the DataSet is large
             but each batch is cheap to prepare, multiprocessing may not bring any speed-up.
         :param pin_memory: bool, defaults to False. Setting it to True may reduce the time spent blocking while
             tensors are moved from CPU to GPU.
         :param timeout: float, a positive number, only effective when num_workers>0. If no batch has been produced
             within this time an error is raised; useful for detecting stalls in batch generation.
+        :param keep_process: bool, defaults to False; only effective with multiprocessing. With num_workers>0, an
+            iterator that is re-created for every epoch keeps spawning and tearing down worker processes, which can
+            hurt speed. With keep_process=True the worker processes stay alive until the Batch object is deleted;
+            use del on the Batch object to delete it and close the processes.
         """

         if num_workers < 0:
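
A minimal usage sketch of keep_process (assuming a DataSet data_set and an epoch count num_epochs are already defined; they are not part of this diff)::

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import SequentialSampler

    # The worker processes are spawned on the first epoch and reused afterwards.
    batch = Batch(data_set, batch_size=16, sampler=SequentialSampler(),
                  num_workers=2, keep_process=True)
    for epoch in range(num_epochs):
        for batch_x, batch_y in batch:  # later epochs reuse the live workers via reset()
            pass                        # training step goes here
    del batch                           # drops the cached iterator and shuts the workers down
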
@@ -45,15 +49,24 @@ class Batch(object):
         self.batch_size = batch_size
         self.sampler = sampler
         self.num_workers = num_workers
+        self.keep_process = keep_process
         self.pin_memory = pin_memory
         self.timeout = timeout
         self.as_numpy = as_numpy
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
+        self._data_iterator = None

     def __iter__(self):
-        # TODO with multiprocessing, every loop re-creates the worker processes, which may be costly; consider reusing the iterator.
-        return _DataLoaderIter(self)
+        if self._data_iterator is not None:
+            # regenerate the index_list
+            self._data_iterator.reset()
+            return self._data_iterator
+        elif self.keep_process and self.num_workers > 0:
+            self._data_iterator = _DataLoaderIter(self)
+            return self._data_iterator
+        else:  # the common case
+            return _DataLoaderIter(self)

     def __len__(self):
         return self.num_batches
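
Stripped of fastNLP specifics, the three-way branch in __iter__ is a reuse-or-create pattern; a self-contained sketch of the same idea (the names here are illustrative, not fastNLP API)::

    class ReusableLoader:
        def __init__(self, make_iter, reuse=False):
            self._make_iter = make_iter  # factory that builds a fresh iterator
            self._reuse = reuse
            self._cached = None

        def __iter__(self):
            if self._cached is not None:      # later epochs: recycle the live iterator
                self._cached.reset()
                return self._cached
            if self._reuse:                   # first epoch with reuse enabled: cache it
                self._cached = self._make_iter()
                return self._cached
            return self._make_iter()          # common case: a fresh iterator per epoch
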
@@ -61,6 +74,12 @@ class Batch(object):
     def get_batch_indices(self):
         return self.cur_batch_indices

+    def __del__(self):
+        if self.keep_process is True:
+            del self._data_iterator
+

 def to_tensor(batch, dtype):
     try:
         if dtype in (int, np.int8, np.int16, np.int32, np.int64):
@@ -276,6 +295,7 @@ class _DataLoaderIter(object):
         self.num_workers = batcher.num_workers
         self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
         self.timeout = batcher.timeout
+        self.keep_process = batcher.keep_process
         self.done_event = threading.Event()
         self.curidx = 0
         self.idx_list = self.sampler(self.dataset)
@@ -335,6 +355,17 @@ class _DataLoaderIter(object):
             for _ in range(2 * self.num_workers):
                 self._put_indices()

+    def reset(self):
+        """
+        Reset curidx and re-sample idx_list. Only useful when keep_process is in effect.
+        :return:
+        """
+        if self.keep_process:
+            self.curidx = 0
+            self.idx_list = self.sampler(self.dataset)
+            for _ in range(2 * self.num_workers):
+                self._put_indices()
+
     def _get_batch(self):
         if self.timeout > 0:
             try:
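
reset() re-primes the workers exactly the way __init__ does: up to 2 * num_workers batches of indices are kept in flight, so each worker has one batch in hand and one waiting. A rough standalone sketch of that priming pattern (illustrative only, not the fastNLP internals)::

    import queue

    def prime_index_queue(index_queue, idx_batches, num_workers):
        # Keep at most 2 * num_workers index batches outstanding.
        for _ in range(2 * num_workers):
            try:
                index_queue.put_nowait(next(idx_batches))
            except StopIteration:
                break

    q = queue.Queue()
    prime_index_queue(q, iter([[0, 1], [2, 3], [4, 5]]), num_workers=2)
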
@@ -366,7 +397,8 @@ class _DataLoaderIter(object):

             # stop once no batches remain outstanding
             if self.batches_outstanding == 0:
-                self._shutdown_workers()
+                if not self.keep_process:
+                    self._shutdown_workers()
                 raise StopIteration

             while True:
@@ -449,4 +481,4 @@ class _DataLoaderIter(object):

     def __del__(self):
         if self.num_workers > 0:
-            self._shutdown_workers()
+            self._shutdown_workers()
@@ -61,7 +61,8 @@ class Trainer(object):
         :param BaseSampler sampler: method used to generate batch data.
         :param num_workers: int, how many worker processes to use to prepare data. Defaults to 0, i.e. data is
             generated in the main thread. This feature is experimental; use with caution. If the DataSet is large
             but each batch is cheap to prepare, multiprocessing may not bring any speed-up.
-        :param pin_memory: bool, defaults to False. Setting it to True may reduce the time spent blocking while tensors are moved from CPU to GPU.
+        :param pin_memory: bool, defaults to False. When True, page-locked memory is used, which may increase memory
+            consumption. If memory is plentiful, consider enabling it for a speed-up; with pin_memory=True the data
+            is moved from CPU to GPU with non_blocking=True by default.
         :param timeout: float, a positive number, only effective when num_workers>0. If no batch has been produced
             within this time an error is raised; useful for detecting stalls in batch generation.
         :param bool use_tqdm: whether to use tqdm to show train progress.
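
The pin_memory / non_blocking pairing described above is standard PyTorch: a host tensor in page-locked memory can be copied to the GPU asynchronously. A minimal sketch (assumes a CUDA device; not part of this diff)::

    import torch

    x = torch.randn(16, 5).pin_memory()  # page-locked host memory
    if torch.cuda.is_available():
        # non_blocking=True lets the host-to-device copy overlap with GPU compute
        x_gpu = x.to('cuda', non_blocking=True)
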
@@ -246,7 +247,8 @@ class Trainer(object):
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout)
+                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
+                                  keep_process=True)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -255,7 +257,8 @@ class Trainer(object):
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)
-                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
+                                               non_blocking=self.pin_memory)  # with pin_memory, use non_blocking
                     prediction = self._data_forward(self.model, batch_x)

                     # edit prediction
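
_move_dict_value_to_device itself is not shown in this diff; a simplified sketch of what such a helper plausibly does (an assumption for illustration, not the actual fastNLP implementation)::

    import torch

    def move_dict_values_to_device(*dicts, device, non_blocking=False):
        # Move every tensor value of every dict onto `device`, in place.
        for d in dicts:
            for key, value in d.items():
                if isinstance(value, torch.Tensor):
                    d[key] = value.to(device, non_blocking=non_blocking)
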
@@ -237,6 +237,31 @@ class TrainerTestGround(unittest.TestCase):
             use_tqdm=False,
             print_every=2)

+    def test_trainer_multiprocess(self):
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=True,
+            print_every=2,
+            num_workers=2,
+            pin_memory=False,
+            timeout=0,
+        )
+        trainer.train()
+
     def test_case2(self):
         # check metrics Wrong
         data_set = prepare_fake_dataset2('x1', 'x2')