Enhance Batch into a multi-process Batch
This commit is contained in:
parent
864c2238f8
commit
2e3ef52a7d
@@ -1,63 +1,59 @@
 import numpy as np
+import random
 import torch
+import torch.multiprocessing as multiprocessing
+from torch.utils.data.dataloader import _set_worker_signal_handlers, _update_worker_pids, \
+    _remove_worker_pids, _error_if_any_worker_fails
+import signal
+import sys
+import threading
+import traceback
+import os
+from torch._six import FileNotFoundError
 
 from fastNLP.core.sampler import RandomSampler
 
 
 class Batch(object):
-    """Batch is an iterable object which iterates over mini-batches.
+    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False,
+                 timeout=0.0):
+        """
+        Batch is an iterable object which iterates over mini-batches.
 
         Example::
 
             for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
                 # ...
 
         :param DataSet dataset: a DataSet object
         :param int batch_size: the size of the batch
         :param Sampler sampler: a Sampler object
-        :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
+        :param bool as_numpy: If True, return Numpy arrays when possible. Otherwise, return torch tensors.
+        :param num_workers: int, the number of worker processes used to prepare the data. Defaults to 0,
+            i.e. batches are produced in the main thread. This feature is experimental; use with caution.
+            If the DataSet is large but each batch takes little time to prepare, multiprocessing may not
+            bring a speedup.
+        :param pin_memory: bool, defaults to False. When True, it may reduce the blocking time of moving
+            tensors from CPU to GPU.
+        :param timeout: float, a positive number, only effective when num_workers > 0. If no batch is
+            obtained within this many seconds an error is raised, which helps detect whether batch
+            production has stalled.
+        """
 
-        """
+        if num_workers < 0:
+            raise ValueError('num_workers option cannot be negative; '
+                             'use num_workers=0 to disable multiprocessing.')
+        if timeout < 0:
+            raise ValueError('timeout option should be non-negative')
 
-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.timeout = timeout
         self.as_numpy = as_numpy
-        self.idx_list = None
-        self.curidx = 0
         self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
         self.cur_batch_indices = None
 
     def __iter__(self):
-        self.idx_list = self.sampler(self.dataset)
-        self.curidx = 0
-        self.lengths = self.dataset.get_length()
-        return self
-
-    def __next__(self):
-        if self.curidx >= len(self.idx_list):
-            raise StopIteration
-        else:
-            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_x, batch_y = {}, {}
-
-            indices = self.idx_list[self.curidx:endidx]
-            self.cur_batch_indices = indices
-
-            for field_name, field in self.dataset.get_all_fields().items():
-                if field.is_target or field.is_input:
-                    batch = field.get(indices)
-                    if not self.as_numpy and field.padder is not None:
-                        batch = to_tensor(batch, field.dtype)
-                    if field.is_target:
-                        batch_y[field_name] = batch
-                    if field.is_input:
-                        batch_x[field_name] = batch
-
-            self.curidx = endidx
-
-            return batch_x, batch_y
+        # TODO: under multiprocessing, every iteration currently re-creates the worker processes,
+        # which may be costly. Consider reusing the iterator directly.
+        return _DataLoaderIter(self)
 
     def __len__(self):
         return self.num_batches
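A usage sketch, not part of the diff; it assumes a fastNLP DataSet whose input/target fields are set. The multi-process path is opt-in and the single-process default keeps the old behaviour::

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import SequentialSampler

    # default: batches are prepared in the main process, as before
    batch = Batch(dataset, batch_size=16, sampler=SequentialSampler())

    # experimental: two worker processes, pinned memory, and an error
    # if no batch arrives within 5 seconds
    batch = Batch(dataset, batch_size=16, sampler=SequentialSampler(),
                  num_workers=2, pin_memory=True, timeout=5.0)
    for batch_x, batch_y in batch:
        pass  # consume one mini-batch per iteration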
@@ -65,7 +61,6 @@ class Batch(object):
     def get_batch_indices(self):
         return self.cur_batch_indices
 
-
 def to_tensor(batch, dtype):
     try:
         if dtype in (int, np.int8, np.int16, np.int32, np.int64):
@@ -75,3 +70,383 @@ def to_tensor(batch, dtype):
     except:
         pass
     return batch
+
+
+"""
+Multiprocessing raises a host of issues (platform differences, shutting processes down safely, etc.),
+so the multi-process speed-up below is adapted directly from the official PyTorch DataLoader.
+"""
+
+IS_WINDOWS = sys.platform == "win32"
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import DWORD, BOOL, HANDLE
+
+if sys.version_info[0] == 2:
+    import Queue as queue
+else:
+    import queue
+
+
+class ExceptionWrapper(object):
+    r"""Wraps an exception plus traceback to communicate across threads"""
+
+    def __init__(self, exc_info):
+        self.exc_type = exc_info[0]
+        self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+
+_use_shared_memory = False
+r"""Whether to use shared memory in default_collate"""
+
+MANAGER_STATUS_CHECK_INTERVAL = 5.0
+
+if IS_WINDOWS:
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+
+            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
+            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+            self.kernel32.OpenProcess.restype = HANDLE
+            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+            self.kernel32.WaitForSingleObject.restype = DWORD
+
+            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+            SYNCHRONIZE = 0x00100000
+            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
+
+            if not self.manager_handle:
+                raise ctypes.WinError(ctypes.get_last_error())
+
+        def is_alive(self):
+            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
+else:
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+
+        def is_alive(self):
+            return os.getppid() == self.manager_pid
+
+
+def _worker_loop(dataset, index_queue, data_queue, seed, worker_id, as_numpy):
+    # the loop that produces the data
+    global _use_shared_memory
+    _use_shared_memory = True
+
+    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+    # module's handlers are executed after Python returns from C low-level
+    # handlers, likely when the same fatal signal happened again already.
+    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
+    _set_worker_signal_handlers()
+
+    torch.set_num_threads(1)
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+    watchdog = ManagerWatchdog()
+
+    while True:
+        try:
+            # fetch the current batch counter together with the indices of the current batch
+            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
+        except queue.Empty:
+            if watchdog.is_alive():
+                continue
+            else:
+                break
+        if r is None:
+            break
+        idx, batch_indices = r
+        try:
+            # fetch the corresponding batch data; adapted here to pull samples from the dataset and apply padding
+            samples = _get_batch_from_dataset(dataset, batch_indices, as_numpy)
+        except Exception:
+            data_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
+        else:
+            data_queue.put((idx, samples, batch_indices))
+            del samples
+
+
+def _get_batch_from_dataset(dataset, indices, as_numpy):
+    """
+    Given indices, fetch (batch_x, batch_y) from the DataSet. Once produced here, the data goes straight
+    to the Trainer if pin_memory is off; with pin_memory it additionally passes through pin_memory_batch().
+    :param dataset: a fastNLP.DataSet object
+    :param indices: List[int], the sample indices
+    :param as_numpy: bool, whether to convert to numpy only (skip the tensor conversion)
+    :return: (batch_x, batch_y)
+    """
+    batch_x, batch_y = {}, {}
+    for field_name, field in dataset.get_all_fields().items():
+        if field.is_target or field.is_input:
+            batch = field.get(indices)
+            if not as_numpy and field.padder is not None:
+                batch = to_tensor(batch, field.dtype)
+            if field.is_target:
+                batch_y[field_name] = batch
+            if field.is_input:
+                batch_x[field_name] = batch
+
+    return batch_x, batch_y
+
+
+def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
+    # forward data into the target queue; if pin_memory is requested, pin it on the way
+    if pin_memory:
+        torch.cuda.set_device(device_id)
+
+    while True:
+        try:
+            r = in_queue.get()
+        except Exception:
+            if done_event.is_set():
+                return
+            raise
+        if r is None:
+            break
+        if isinstance(r[1], ExceptionWrapper):
+            out_queue.put(r)
+            continue
+        idx, batch, batch_indices = r
+        try:
+            if pin_memory:
+                batch = pin_memory_batch(batch)
+        except Exception:
+            out_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices))
+        else:
+            out_queue.put((idx, batch, batch_indices))
+
+
+def pin_memory_batch(batchs):
+    """
+    :param batchs: (batch_x, batch_y)
+    :return: (batch_x, batch_y)
+    """
+    for batch_dict in batchs:
+        for field_name, batch in batch_dict.items():
+            if isinstance(batch, torch.Tensor):
+                batch_dict[field_name] = batch.pin_memory()
+    return batchs
+
+
+_SIGCHLD_handler_set = False
+r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+    # Windows doesn't support SIGCHLD handler
+    if sys.platform == 'win32':
+        return
+    # can't set signal in child threads
+    if not isinstance(threading.current_thread(), threading._MainThread):
+        return
+    global _SIGCHLD_handler_set
+    if _SIGCHLD_handler_set:
+        return
+    previous_handler = signal.getsignal(signal.SIGCHLD)
+    if not callable(previous_handler):
+        previous_handler = None
+
+    def handler(signum, frame):
+        # This following call uses `waitid` with WNOHANG from C side. Therefore,
+        # Python can still get and update the process status successfully.
+        _error_if_any_worker_fails()
+        if previous_handler is not None:
+            previous_handler(signum, frame)
+
+    signal.signal(signal.SIGCHLD, handler)
+    _SIGCHLD_handler_set = True
+
+
+class _DataLoaderIter(object):
+    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
+
+    def __init__(self, batcher):
+        self.batcher = batcher
+        self.dataset = batcher.dataset
+        self.sampler = batcher.sampler
+        self.as_numpy = batcher.as_numpy
+        self.batch_size = batcher.batch_size
+        self.num_workers = batcher.num_workers
+        self.pin_memory = batcher.pin_memory and torch.cuda.is_available()
+        self.timeout = batcher.timeout
+        self.done_event = threading.Event()
+        self.curidx = 0
+        self.idx_list = self.sampler(self.dataset)
+
+        # self.sample_iter returns one index at a time; this could be replaced by another mechanism
+
+        base_seed = torch.LongTensor(1).random_().item()
+
+        if self.num_workers > 0:
+            # one index queue per worker
+            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
+            self.worker_queue_idx = 0
+            # holds the fetched batches
+            self.worker_result_queue = multiprocessing.SimpleQueue()
+            self.batches_outstanding = 0
+            self.worker_pids_set = False
+            self.shutdown = False
+            self.send_idx = 0
+            self.rcvd_idx = 0
+            self.reorder_dict = {}
+
+            # the workers feed batch data into self.worker_result_queue; it is not moved to the device yet
+            self.workers = [
+                multiprocessing.Process(
+                    target=_worker_loop,
+                    args=(self.dataset, self.index_queues[i],
+                          self.worker_result_queue, base_seed + i, i, self.as_numpy))
+                for i in range(self.num_workers)]
+
+            # just read from self.data_queue; with pin_memory, data is relayed into another queue
+            if self.pin_memory or self.timeout > 0:
+                self.data_queue = queue.Queue()
+                if self.pin_memory:
+                    maybe_device_id = torch.cuda.current_device()
+                else:
+                    # do not initialize cuda context if not necessary
+                    maybe_device_id = None
+                self.worker_manager_thread = threading.Thread(
+                    target=_worker_manager_loop,
+                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
+                          maybe_device_id))
+                self.worker_manager_thread.daemon = True
+                self.worker_manager_thread.start()
+            else:
+                self.data_queue = self.worker_result_queue
+
+            # set the workers to work
+            for w in self.workers:
+                w.daemon = True  # ensure that the worker exits on process exit
+                w.start()
+
+            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
+            _set_SIGCHLD_handler()
+            self.worker_pids_set = True
+
+            # prime the prefetch loop
+            for _ in range(2 * self.num_workers):
+                self._put_indices()
+
+    def _get_batch(self):
+        if self.timeout > 0:
+            try:
+                return self.data_queue.get(timeout=self.timeout)
+            except queue.Empty:
+                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
+        else:
+            return self.data_queue.get()
+
+    def __next__(self):
+        if self.num_workers == 0:  # same-process loading
+            if self.curidx >= len(self.idx_list):
+                raise StopIteration
+            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
+            # just collect the data directly from the dataset
+            indices = self.idx_list[self.curidx:endidx]
+            self.batcher.cur_batch_indices = indices
+            batch_x, batch_y = _get_batch_from_dataset(dataset=self.dataset, indices=indices,
+                                                       as_numpy=self.as_numpy)
+            if self.pin_memory:
+                batch_x, batch_y = pin_memory_batch((batch_x, batch_y))
+            self.curidx = endidx
+            return batch_x, batch_y
+
+        # check if the next sample has already been generated
+        if self.rcvd_idx in self.reorder_dict:
+            batch = self.reorder_dict.pop(self.rcvd_idx)
+            return self._process_next_batch(batch)
+
+        # stop once no batches remain outstanding
+        if self.batches_outstanding == 0:
+            self._shutdown_workers()
+            raise StopIteration
+
+        while True:
+            assert (not self.shutdown and self.batches_outstanding > 0)
+            idx, batch, batch_indices = self._get_batch()
+            self.batches_outstanding -= 1
+            if idx != self.rcvd_idx:
+                # store out-of-order samples
+                self.reorder_dict[idx] = batch
+                continue
+            self.batcher.cur_batch_indices = batch_indices
+            return self._process_next_batch(batch)
+
+    def __iter__(self):
+        self.curidx = 0
+
+        return self
+
+    def _put_indices(self):
+        # push indices into the index queue that the workers read from
+        assert self.batches_outstanding < 2 * self.num_workers
+        if self.curidx >= len(self.idx_list):
+            indices = None
+        else:
+            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
+            # just collect the indices directly from the dataset order
+            indices = self.idx_list[self.curidx:endidx]
+        if indices is None:
+            return
+        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
+        self.curidx = endidx
+        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
+        self.batches_outstanding += 1
+        self.send_idx += 1
+
+    def _process_next_batch(self, batch):
+        # just a trigger to produce the next batch of indices
+        self.rcvd_idx += 1
+        self._put_indices()
+        if isinstance(batch, ExceptionWrapper):
+            raise batch.exc_type(batch.exc_msg)
+        return batch
+
+    def __getstate__(self):
+        # TODO: add limited pickling support for sharing an iterator
+        # across multiple threads for HOGWILD.
+        # Probably the best way to do this is by moving the sample pushing
+        # to a separate thread and then just sharing the data queue
+        # but signalling the end is tricky without a non-blocking API
+        raise NotImplementedError("_DataLoaderIter cannot be pickled")
+
+    def _shutdown_workers(self):
+        try:
+            if not self.shutdown:
+                self.shutdown = True
+                self.done_event.set()
+                for q in self.index_queues:
+                    q.put(None)
+                # if some workers are waiting to put, make place for them
+                try:
+                    while not self.worker_result_queue.empty():
+                        self.worker_result_queue.get()
+                except (FileNotFoundError, ImportError):
+                    # Many weird errors can happen here due to Python
+                    # shutting down. These are more like obscure Python bugs.
+                    # FileNotFoundError can happen when we rebuild the fd
+                    # fetched from the queue but the socket is already closed
+                    # from the worker side.
+                    # ImportError can happen when the unpickler loads the
+                    # resource from `get`.
+                    pass
+                # done_event should be sufficient to exit worker_manager_thread,
+                # but be safe here and put another None
+                self.worker_result_queue.put(None)
+        finally:
+            # removes pids no matter what
+            if self.worker_pids_set:
+                _remove_worker_pids(id(self))
+                self.worker_pids_set = False
+
+    def __del__(self):
+        if self.num_workers > 0:
+            self._shutdown_workers()
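A note on the prefetch logic above: _put_indices keeps at most 2 * num_workers batches in flight, and __next__ restores arrival order through reorder_dict. A toy, dependency-free sketch of that reordering (illustrative data only)::

    results = {2: 'c', 0: 'a', 1: 'b'}   # idx -> batch, finishing out of order
    reorder_dict, rcvd_idx, ordered = {}, 0, []
    for idx in (2, 0, 1):                # simulated arrival order from the result queue
        reorder_dict[idx] = results[idx]
        while rcvd_idx in reorder_dict:  # emit every batch whose turn has come
            ordered.append(reorder_dict.pop(rcvd_idx))
            rcvd_idx += 1
    assert ordered == ['a', 'b', 'c']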
@@ -408,7 +408,7 @@ class EngChar2DPadder(PadderBase):
         except:
             raise ValueError("Field:{} only has one dimension.".format(field_name))
         try:
-            value = value[1]
+            value = value[0]
         except:
             raise ValueError("Field:{} only has two dimensions.".format(field_name))
 
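A reading of the one-line fix above (my interpretation, with hypothetical data): the padder probes each nesting level by indexing into the first element, and indexing with [1] instead of [0] would spuriously fail on rows of length 1::

    value = [[7]]      # one sample whose row holds a single element
    value = value[0]   # first row -> [7]
    value = value[0]   # succeeds; value[1] would raise IndexError although a second dimension exists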
@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
-                 callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
+                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -46,22 +46,27 @@ class Trainer(object):
         :param int print_every: step interval to print next training information. Default: -1(no print).
         :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
         :param DataSet dev_data: the validation data
-        :param bool use_cuda: whether to use CUDA in training.
         :param str save_path: file path to save models
         :param Optimizer optimizer: an optimizer object
         :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\
             `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means
             it will raise error if some field are not used. The check works by running the code on a very small
             batch (two samples by default); in principle this does not modify any parameters, it only verifies
-            that the code runs. But if (1) the model hard-codes batch_size to some fixed value, or (2) the model
-            accumulates a forward-pass counter, a few extra passes may be computed. It is recommended to set
-            check_code_level to -1.
+            that the code runs. But if (1) the model hard-codes batch_size to some fixed value, or (2) the model
+            accumulates a forward-pass counter, a few extra passes may be computed. In those cases it is
+            recommended to set check_code_level to -1.
         :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
             of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
             smaller, add "-" in front of the string. For example::
 
                 metric_key="-PPL"   # language model gets better as perplexity gets smaller
         :param BaseSampler sampler: method used to generate batch data.
+        :param num_workers: int, the number of worker processes used to prepare the data. Defaults to 0, i.e.
+            batches are produced in the main thread. This feature is experimental; use with caution. If the
+            DataSet is large but each batch takes little time to prepare, multiprocessing may not bring a speedup.
+        :param pin_memory: bool, defaults to False. When True, it may reduce the blocking time of moving tensors
+            from CPU to GPU.
+        :param timeout: float, a positive number, only effective when num_workers > 0. If no batch is obtained
+            within this many seconds an error is raised, which helps detect whether batch production has stalled.
         :param bool use_tqdm: whether to use tqdm to show train progress.
+        :param callbacks: List[Callback]. Callback functions used to intervene in the training process; e.g.
+            early stopping and negative sampling can be implemented via the callback mechanism.
         """
         super(Trainer, self).__init__()
@@ -117,6 +122,9 @@ class Trainer(object):
         self.validate_every = int(validate_every) if validate_every!=0 else -1
         self.best_metric_indicator = None
         self.sampler = sampler
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.timeout = timeout
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -237,7 +245,8 @@ class Trainer(object):
                                len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
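A sketch of how the new Trainer arguments thread through to Batch (model, data, loss and metrics are placeholders)::

    trainer = Trainer(train_data=train_data, model=model, loss=loss, metrics=metrics,
                      batch_size=32,
                      num_workers=2,    # experimental multi-process batch preparation
                      pin_memory=True,  # pin batches to speed up host-to-GPU copies
                      timeout=10,       # fail fast if batch production stalls
                      use_cuda=True)
    trainer.train()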
@@ -186,11 +186,12 @@ def _check_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")
 
 
-def _move_dict_value_to_device(*args, device: torch.device):
+def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
     """
 
     move data to model's device, element in *args should be dict. This is an inplace change.
     :param device: torch.device
+    :param non_blocking: bool, whether to move the data to the device asynchronously; requires the tensors to be pinned via pin_memory()
     :param args:
     :return:
     """
@@ -201,7 +202,7 @@ def _move_dict_value_to_device(*args, device: torch.device):
     if isinstance(arg, dict):
         for key, value in arg.items():
             if isinstance(value, torch.Tensor):
-                arg[key] = value.to(device)
+                arg[key] = value.to(device, non_blocking=non_blocking)
     else:
         raise TypeError("Only support `dict` type right now.")
 
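pin_memory and non_blocking work as a pair: only pinned host tensors can be copied to the GPU asynchronously. A minimal illustration of the pattern the Trainer now relies on (assumes a CUDA device)::

    import torch

    batch_x = {'words': torch.randint(0, 100, (32, 50)).pin_memory()}
    device = torch.device('cuda')
    for key, value in batch_x.items():
        # with a pinned source tensor, this copy can overlap with computation
        batch_x[key] = value.to(device, non_blocking=True)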
@@ -8,7 +8,35 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.dataset import construct_dataset
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
+import time
+
+
+def generate_fake_dataset(num_samples=1000):
+    """
+    The generated DataSet contains the fields {'0': [], '1': [], '2': [], '3': []}.
+    :param num_samples: the number of samples
+    :return:
+    """
+    max_len = 50
+    min_len = 10
+    num_features = 4
+
+    data_dict = {}
+    for i in range(num_features):
+        data = []
+        lengths = np.random.randint(min_len, max_len, size=(num_samples))
+        for length in lengths:
+            data.append(np.random.randint(100, size=length))
+        data_dict[str(i)] = data
+
+    dataset = DataSet(data_dict)
+
+    for i in range(num_features):
+        if np.random.randint(2) == 0:
+            dataset.set_input(str(i))
+        else:
+            dataset.set_target(str(i))
+    return dataset
+
 
 class TestCase1(unittest.TestCase):
     def test_simple(self):
@@ -98,3 +126,47 @@ class TestCase1(unittest.TestCase):
         iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             print(x, y)
+
+    def test_sequential_batch(self):
+        batch_size = 32
+        pause_seconds = 0.01
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+
+        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
+        for batch_x, batch_y in batch:
+            time.sleep(pause_seconds)
+
+    def test_multi_workers_batch(self):
+        batch_size = 32
+        pause_seconds = 0.01
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+
+        num_workers = 1
+        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
+        for batch_x, batch_y in batch:
+            time.sleep(pause_seconds)
+
+        num_workers = 2
+        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
+        end1 = time.time()
+        for batch_x, batch_y in batch:
+            time.sleep(pause_seconds)
+
+    def test_pin_memory(self):
+        batch_size = 32
+        pause_seconds = 0.01
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+
+        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
+        for batch_x, batch_y in batch:
+            time.sleep(pause_seconds)
+
+        num_workers = 2
+        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
+                      pin_memory=True)
+        for batch_x, batch_y in batch:
+            time.sleep(pause_seconds)
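The tests above only exercise the code paths; whether multiprocessing actually pays off depends on the dataset. A rough timing comparison along these lines can help (the sleep stands in for one training step)::

    import time

    for num_workers in (0, 2):
        batch = Batch(dataset, batch_size=32, sampler=SequentialSampler(),
                      num_workers=num_workers)
        start = time.time()
        for batch_x, batch_y in batch:
            time.sleep(0.01)
        print(num_workers, 'workers:', round(time.time() - start, 2), 'seconds')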