Enhance Batch into a multi-process Batch

yh_cc 2019-01-18 23:02:15 +08:00
parent 864c2238f8
commit 2e3ef52a7d
5 changed files with 510 additions and 53 deletions

View File

@@ -1,63 +1,59 @@
import numpy as np
import random
import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data.dataloader import _set_worker_signal_handlers, _update_worker_pids, \
_remove_worker_pids, _error_if_any_worker_fails
import signal
import sys
import threading
import traceback
import os
from torch._six import FileNotFoundError
from fastNLP.core.sampler import RandomSampler
class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
timeout=0.0):
"""
Batch is an iterable object which iterates over mini-batches.
Example::
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
# ...
:param DataSet dataset: a DataSet object
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object
:param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors.
:param num_workers: int, how many subprocesses to use to prepare the data. Default: 0, i.e. data is produced in
the main process. This feature is experimental; use with caution. If the DataSet is large but each batch is
cheap to prepare, multiprocessing may not bring any speed-up.
:param pin_memory: bool, default False. If True, it may save some of the blocking time spent moving tensors
from CPU to GPU.
:param timeout: float, a positive number, only effective when num_workers>0. If no batch is obtained within
this many seconds an error is raised; useful for detecting whether batch production has stalled.
"""
"""
if num_workers < 0:
raise ValueError('num_workers option cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
if timeout < 0:
raise ValueError('timeout option should be non-negative')
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.as_numpy = as_numpy
self.idx_list = None
self.curidx = 0
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
self.cur_batch_indices = None
def __iter__(self):
# TODO In the multi-worker case, every epoch currently re-creates the worker processes, which may be costly; consider reusing the iterator directly.
return _DataLoaderIter(self)
def __len__(self):
return self.num_batches
@@ -65,7 +61,6 @@ class Batch(object):
def get_batch_indices(self):
return self.cur_batch_indices
def to_tensor(batch, dtype):
try:
if dtype in (int, np.int8, np.int16, np.int32, np.int64):
@@ -75,3 +70,383 @@ def to_tensor(batch, dtype):
except:
pass
return batch
"""
由于多进程涉及到大量问题包括系统安全关闭进程等所以这里直接从pytorch的官方版本修改DataLoader实现多进程加速
"""
IS_WINDOWS = sys.platform == "win32"
if IS_WINDOWS:
import ctypes
from ctypes.wintypes import DWORD, BOOL, HANDLE
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads"""
def __init__(self, exc_info):
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""
MANAGER_STATUS_CHECK_INTERVAL = 5.0
if IS_WINDOWS:
# On Windows, the parent ID of the worker process remains unchanged when the manager process
# is gone, and the only way to check it through OS is to let the worker have a process handle
# of the manager and ask if the process status has changed.
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
self.kernel32.OpenProcess.restype = HANDLE
self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
self.kernel32.WaitForSingleObject.restype = DWORD
# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
SYNCHRONIZE = 0x00100000
self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
if not self.manager_handle:
raise ctypes.WinError(ctypes.get_last_error())
def is_alive(self):
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
def is_alive(self):
return os.getppid() == self.manager_pid
def _worker_loop(dataset, index_queue, data_queue, seed, worker_id, as_numpy):
# the loop in each worker that produces batches
global _use_shared_memory
_use_shared_memory = True
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
watchdog = ManagerWatchdog()
while True:
try:
# fetch (batch counter, indices of the current batch)
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive():
continue
else:
break
if r is None:
break
idx, batch_indices = r
try:
# fetch the corresponding batch data; modified here to pull the data from the dataset and finish padding
samples = _get_batch_from_dataset(dataset, batch_indices, as_numpy)
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
else:
data_queue.put((idx, samples, batch_indices))
del samples
def _get_batch_from_dataset(dataset, indices, as_numpy):
"""
给定indices从DataSet中取出(batch_x, batch_y). 数据从这里产生后若没有pin_memory, 则直接传递给Trainer了如果存在
pin_memory还会经过一道pin_memory()的处理
:param dataset: fastNLP.DataSet对象
:param indices: List[int], index
:param as_numpy: bool, 是否只是转换为numpy
:return: (batch_x, batch_y)
"""
batch_x, batch_y = {}, {}
for field_name, field in dataset.get_all_fields().items():
if field.is_target or field.is_input:
batch = field.get(indices)
if not as_numpy and field.padder is not None:
batch = to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
batch_x[field_name] = batch
return batch_x, batch_y
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
# relay data into the designated queue; i.e. if pin_memory is requested, pin each batch on the way through
if pin_memory:
torch.cuda.set_device(device_id)
while True:
try:
r = in_queue.get()
except Exception:
if done_event.is_set():
return
raise
if r is None:
break
if isinstance(r[1], ExceptionWrapper):
out_queue.put(r)
continue
idx, batch, batch_indices = r
try:
if pin_memory:
batch = pin_memory_batch(batch)
except Exception:
out_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
else:
out_queue.put((idx, batch, batch_indices))
def pin_memory_batch(batchs):
"""
:param batchs: (batch_x, batch_y)
:return: (batch_x, batch_y)
"""
for batch_dict in batchs:
for field_name, batch in batch_dict.items():
if isinstance(batch, torch.Tensor):
batch_dict[field_name] = batch.pin_memory()
return batchs
_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""
def _set_SIGCHLD_handler():
# Windows doesn't support SIGCHLD handler
if sys.platform == 'win32':
return
# can't set signal in child threads
if not isinstance(threading.current_thread(), threading._MainThread):
return
global _SIGCHLD_handler_set
if _SIGCHLD_handler_set:
return
previous_handler = signal.getsignal(signal.SIGCHLD)
if not callable(previous_handler):
previous_handler = None
def handler(signum, frame):
# This following call uses `waitid` with WNOHANG from C side. Therefore,
# Python can still get and update the process status successfully.
_error_if_any_worker_fails()
if previous_handler is not None:
previous_handler(signum, frame)
signal.signal(signal.SIGCHLD, handler)
_SIGCHLD_handler_set = True
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
def __init__(self, batcher):
self.batcher = batcher
self.dataset = batcher.dataset
self.sampler = batcher.sampler
self.as_numpy = batcher.as_numpy
self.batch_size = batcher.batch_size
self.num_workers = batcher.num_workers
self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
self.timeout = batcher.timeout
self.done_event = threading.Event()
self.curidx = 0
self.idx_list = self.sampler(self.dataset)
# self.sample_iter returns one index at a time; this could be replaced by other means
base_seed = torch.LongTensor(1).random_().item()
if self.num_workers > 0:
# build one index queue per worker
self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
self.worker_queue_idx = 0
# holds the batches fetched from the workers
self.worker_result_queue = multiprocessing.SimpleQueue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {}
# the workers deliver batch data into self.worker_result_queue here; nothing is moved to a device yet
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queues[i],
self.worker_result_queue, base_seed + i, i, self.as_numpy))
for i in range(self.num_workers)]
# consumers just read from self.data_queue; with pin_memory the data is relayed into a second queue
if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
if self.pin_memory:
maybe_device_id = torch.cuda.current_device()
else:
# do not initialize cuda context if not necessary
maybe_device_id = None
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
maybe_device_id))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
self.data_queue = self.worker_result_queue
# the workers start working
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
return self.data_queue.get()
def __next__(self):
if self.num_workers == 0: # same-process loading
if self.curidx >= len(self.idx_list):
raise StopIteration
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
# simply gather the data straight from the dataset
indices = self.idx_list[self.curidx:endidx]
self.batcher.cur_batch_indices = indices
batch_x, batch_y = _get_batch_from_dataset(dataset=self.dataset, indices=indices,
as_numpy=self.as_numpy)
if self.pin_memory:
batch_x, batch_y = pin_memory_batch((batch_x, batch_y))
self.curidx = endidx
return batch_x, batch_y
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
# stop once no batches are outstanding
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch, batch_indices = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
self.batcher.cur_batch_indices = batch_indices
return self._process_next_batch(batch)
def __iter__(self):
self.curidx = 0
return self
def _put_indices(self):
# put indices into a worker's index queue
assert self.batches_outstanding < 2 * self.num_workers
if self.curidx >= len(self.idx_list):
indices = None
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
# slice out the indices for the next batch
indices = self.idx_list[self.curidx:endidx]
if indices is None:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
self.curidx = endidx
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
self.batches_outstanding += 1
self.send_idx += 1
def _process_next_batch(self, batch):
# just prompts production of the next batch of indices
self.rcvd_idx += 1
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("_DataLoaderIter cannot be pickled")
def _shutdown_workers(self):
try:
if not self.shutdown:
self.shutdown = True
self.done_event.set()
for q in self.index_queues:
q.put(None)
# if some workers are waiting to put, make place for them
try:
while not self.worker_result_queue.empty():
self.worker_result_queue.get()
except (FileNotFoundError, ImportError):
# Many weird errors can happen here due to Python
# shutting down. These are more like obscure Python bugs.
# FileNotFoundError can happen when we rebuild the fd
# fetched from the queue but the socket is already closed
# from the worker side.
# ImportError can happen when the unpickler loads the
# resource from `get`.
pass
# done_event should be sufficient to exit worker_manager_thread,
# but be safe here and put another None
self.worker_result_queue.put(None)
finally:
# removes pids no matter what
if self.worker_pids_set:
_remove_worker_pids(id(self))
self.worker_pids_set = False
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
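
For orientation, a minimal usage sketch of the extended Batch. The toy DataSet below is illustrative, not part of this commit; any DataSet with input/target fields set behaves the same way:

import numpy as np
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# Toy dataset: one input field and one target field.
ds = DataSet({'x': [np.arange(10) for _ in range(100)], 'y': list(range(100))})
ds.set_input('x')
ds.set_target('y')

batch = Batch(ds, batch_size=16, sampler=SequentialSampler(),
              num_workers=2,    # prepare batches in two subprocesses (experimental)
              pin_memory=True,  # effectively ignored when CUDA is unavailable
              timeout=10.0)     # raise if no batch arrives within 10 seconds
for batch_x, batch_y in batch:
    pass  # feed batch_x / batch_y to the model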

View File

@@ -408,7 +408,7 @@ class EngChar2DPadder(PadderBase):
except:
raise ValueError("Field:{} only has one dimension.".format(field_name))
try:
value = value[0]
except:
raise ValueError("Field:{} only has two dimensions.".format(field_name))

View File

@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
class Trainer(object):
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
"""
:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
@@ -46,22 +46,27 @@ class Trainer(object):
:param int print_every: step interval to print next training information. Default: -1(no print).
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param bool use_cuda: whether to use CUDA in training.
:param str save_path: file path to save models
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\
`ignore` will not check unused fields; `warning` warns if some fields are not used; `strict` means an error
is raised if some fields are not used. The check works by running the code with a very small batch (two
samples by default); in principle this does not modify any parameters, it only verifies that the code runs.
However, if (1) the model hard-codes batch_size to some fixed value, or (2) the model accumulates a count of
forward passes, a few extra computations may occur; in those cases it is recommended to set check_code_level
to -1.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the indicator gets
smaller, add "-" in front of the string. For example::
metric_key="-PPL" # language model gets better as perplexity gets smaller
:param BaseSampler sampler: method used to generate batch data.
:param num_workers: int, how many subprocesses to use to prepare the data. Default: 0, i.e. data is produced in
the main process. This feature is experimental; use with caution. If the DataSet is large but each batch is
cheap to prepare, multiprocessing may not bring any speed-up.
:param pin_memory: bool, default False. If True, it may save some of the blocking time spent moving tensors
from CPU to GPU.
:param timeout: float, a positive number, only effective when num_workers>0. If no batch is obtained within
this many seconds an error is raised; useful for detecting whether batch production has stalled.
:param bool use_tqdm: whether to use tqdm to show train progress.
:param callbacks: List[Callback]. Callbacks used to hook into the training process, e.g. early stopping,
negative sampling, etc., implemented through the callback mechanism.
"""
super(Trainer, self).__init__()
@@ -117,6 +122,9 @@ class Trainer(object):
self.validate_every = int(validate_every) if validate_every!=0 else -1
self.best_metric_indicator = None
self.sampler = sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
self.timeout = timeout
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
if isinstance(optimizer, torch.optim.Optimizer):
@@ -237,7 +245,8 @@ class Trainer(object):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
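
A runnable sketch of driving Trainer with the new options. The toy model and dataset are invented for illustration; it assumes this era's default loss (LossInForward) kicks in when forward() returns a 'loss' key:

import numpy as np
import torch
from fastNLP.core.dataset import DataSet
from fastNLP.core.trainer import Trainer

class ToyModel(torch.nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = torch.nn.Linear(5, 1)

    def forward(self, x):
        # returning a 'loss' key lets Trainer's default LossInForward apply
        return {'loss': self.fc(x.float()).abs().mean()}

ds = DataSet({'x': [np.arange(5) for _ in range(64)]})
ds.set_input('x')

trainer = Trainer(train_data=ds, model=ToyModel(), n_epochs=1, batch_size=16,
                  num_workers=2,    # batches prepared by two subprocesses
                  pin_memory=False, # pair with use_cuda=True to speed up transfers
                  timeout=30,       # fail instead of hanging if batch production stalls
                  check_code_level=-1, use_tqdm=False)
trainer.train()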

View File

@@ -186,11 +186,12 @@ def _check_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")
def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
"""
Move data to the model's device; each element in *args should be a dict. This is an in-place change.
:param device: torch.device
:param non_blocking: bool, whether to move the data to the device asynchronously; the tensors need to be pinned via pin_memory()
:param args:
:return:
"""
@@ -201,7 +202,7 @@ def _move_dict_value_to_device(*args, device: torch.device):
if isinstance(arg, dict):
for key, value in arg.items():
if isinstance(value, torch.Tensor):
arg[key] = value.to(device, non_blocking=non_blocking)
else:
raise TypeError("Only support `dict` type right now.")

View File

@@ -8,7 +8,35 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.dataset import construct_dataset
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler
import time
def generate_fake_dataset(num_samples=1000):
"""
The generated DataSet contains the following fields: {'0': [], '1': [], '2': [], '3': []}
:param num_samples: the number of samples
:return:
"""
max_len = 50
min_len = 10
num_features = 4
data_dict = {}
for i in range(num_features):
data = []
lengths = np.random.randint(min_len, max_len, size=(num_samples))
for length in lengths:
data.append(np.random.randint(100, size=length))
data_dict[str(i)] = data
dataset = DataSet(data_dict)
for i in range(num_features):
if np.random.randint(2) == 0:
dataset.set_input(str(i))
else:
dataset.set_target(str(i))
return dataset
class TestCase1(unittest.TestCase):
def test_simple(self):
@@ -98,3 +126,47 @@ class TestCase1(unittest.TestCase):
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for x, y in iter:
print(x, y)
def test_sequential_batch(self):
batch_size = 32
pause_seconds = 0.01
num_samples = 1000
dataset = generate_fake_dataset(num_samples)
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
def test_multi_workers_batch(self):
batch_size = 32
pause_seconds = 0.01
num_samples = 1000
dataset = generate_fake_dataset(num_samples)
num_workers = 1
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
num_workers = 2
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
end1 = time.time()
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
def test_pin_memory(self):
batch_size = 32
pause_seconds = 0.01
num_samples = 1000
dataset = generate_fake_dataset(num_samples)
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
num_workers = 2
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
pin_memory=True)
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
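
The tests above exercise the new code paths but assert nothing about speed; a rough manual comparison could look like this (the sleep stands in for model compute; nothing is asserted):

def compare_batch_speed(dataset, batch_size=32, pause_seconds=0.01):
    # Rough, non-asserting comparison of single- vs multi-process batching.
    for num_workers in (0, 2):
        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(),
                      num_workers=num_workers)
        start = time.time()
        for batch_x, batch_y in batch:
            time.sleep(pause_seconds)  # stand-in for the forward/backward pass
        print('num_workers=%d: %.2fs' % (num_workers, time.time() - start))

# compare_batch_speed(generate_fake_dataset(1000))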