From 2e3ef52a7d47598e92707f0ee5c3251eb68bcb95 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 18 Jan 2019 23:02:15 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86batch=E5=A2=9E=E5=BC=BA=E4=B8=BA?= =?UTF-8?q?=E5=A4=9A=E8=BF=9B=E7=A8=8Bbatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/batch.py | 459 +++++++++++++++++++++++++++++++++---- fastNLP/core/fieldarray.py | 2 +- fastNLP/core/trainer.py | 25 +- fastNLP/core/utils.py | 5 +- test/core/test_batch.py | 72 ++++++ 5 files changed, 510 insertions(+), 53 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index d4fcbf23..05bd5665 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -1,63 +1,59 @@ import numpy as np +import random import torch +import torch.multiprocessing as multiprocessing +from torch.utils.data.dataloader import _set_worker_signal_handlers, _update_worker_pids, \ + _remove_worker_pids, _error_if_any_worker_fails +import signal +import sys +import threading +import traceback +import os +from torch._six import FileNotFoundError from fastNLP.core.sampler import RandomSampler - class Batch(object): - """Batch is an iterable object which iterates over mini-batches. + def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, num_workers=0, pin_memory=False, + timeout=0.0): + """ + Batch is an iterable object which iterates over mini-batches. - Example:: + Example:: - for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): - # ... + for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): + # ... - :param DataSet dataset: a DataSet object - :param int batch_size: the size of the batch - :param Sampler sampler: a Sampler object - :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. + :param DataSet dataset: a DataSet object + :param int batch_size: the size of the batch + :param Sampler sampler: a Sampler object + :param bool as_numpy: If True, return Numpy array when possible. Otherwise, return torch tensors. + :param num_workers: int, 使用多少个进程来准备数据。默认为0, 即使用主线程生成数据。 特性处于实验阶段,谨慎使用。 + 如果DataSet较大,且每个batch的准备时间很短,使用多进程可能并不能提速。 + :param pin_memory: bool, 默认为False. 
设置为True时,有可能可以节省tensor从cpu移动到gpu的阻塞时间。 + :param timeout: float, 大于0的数,只有在num_workers>0时才有用。超过该时间仍然没有获取到一个batch则报错,可以用于 + 检测是否出现了batch产生阻塞的情况。 + """ - """ + if num_workers < 0: + raise ValueError('num_workers option cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') + if timeout < 0: + raise ValueError('timeout option should be non-negative') - def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): self.dataset = dataset self.batch_size = batch_size self.sampler = sampler + self.num_workers = num_workers + self.pin_memory = pin_memory + self.timeout = timeout self.as_numpy = as_numpy - self.idx_list = None - self.curidx = 0 self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) self.cur_batch_indices = None def __iter__(self): - self.idx_list = self.sampler(self.dataset) - self.curidx = 0 - self.lengths = self.dataset.get_length() - return self - - def __next__(self): - if self.curidx >= len(self.idx_list): - raise StopIteration - else: - endidx = min(self.curidx + self.batch_size, len(self.idx_list)) - batch_x, batch_y = {}, {} - - indices = self.idx_list[self.curidx:endidx] - self.cur_batch_indices = indices - - for field_name, field in self.dataset.get_all_fields().items(): - if field.is_target or field.is_input: - batch = field.get(indices) - if not self.as_numpy and field.padder is not None: - batch = to_tensor(batch, field.dtype) - if field.is_target: - batch_y[field_name] = batch - if field.is_input: - batch_x[field_name] = batch - - self.curidx = endidx - - return batch_x, batch_y + # TODO 现在多线程的情况下每个循环都会重新创建多进程,开销可能有点大。可以考虑直接复用iterator. + return _DataLoaderIter(self) def __len__(self): return self.num_batches @@ -65,7 +61,6 @@ class Batch(object): def get_batch_indices(self): return self.cur_batch_indices - def to_tensor(batch, dtype): try: if dtype in (int, np.int8, np.int16, np.int32, np.int64): @@ -75,3 +70,383 @@ def to_tensor(batch, dtype): except: pass return batch + + +""" +由于多进程涉及到大量问题,包括系统、安全关闭进程等。所以这里直接从pytorch的官方版本修改DataLoader实现多进程加速 +""" + +IS_WINDOWS = sys.platform == "win32" +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import DWORD, BOOL, HANDLE + +if sys.version_info[0] == 2: + import Queue as queue +else: + import queue + + +class ExceptionWrapper(object): + r"""Wraps an exception plus traceback to communicate across threads""" + + def __init__(self, exc_info): + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +_use_shared_memory = False +r"""Whether to use shared memory in default_collate""" + +MANAGER_STATUS_CHECK_INTERVAL = 5.0 + +if IS_WINDOWS: + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. 
+ class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) + + def is_alive(self): + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0 +else: + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + def is_alive(self): + return os.getppid() == self.manager_pid + + +def _worker_loop(dataset, index_queue, data_queue, seed, worker_id, as_numpy): + # 产生数据的循环 + global _use_shared_memory + _use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal happened again already. + # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 + _set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + + watchdog = ManagerWatchdog() + + while True: + try: + # 获取当前batch计数,当前batch的indexes + r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL) + except queue.Empty: + if watchdog.is_alive(): + continue + else: + break + if r is None: + break + idx, batch_indices = r + try: + # 获取相应的batch数据。这里需要修改为从dataset中取出数据并且完成padding + samples = _get_batch_from_dataset(dataset, batch_indices, as_numpy) + except Exception: + data_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices)) + else: + data_queue.put((idx, samples, batch_indices)) + del samples + +def _get_batch_from_dataset(dataset, indices, as_numpy): + """ + 给定indices,从DataSet中取出(batch_x, batch_y). 数据从这里产生后,若没有pin_memory, 则直接传递给Trainer了,如果存在 + pin_memory还会经过一道pin_memory()的处理 + :param dataset: fastNLP.DataSet对象 + :param indices: List[int], index + :param as_numpy: bool, 是否只是转换为numpy + :return: (batch_x, batch_y) + """ + batch_x, batch_y = {}, {} + for field_name, field in dataset.get_all_fields().items(): + if field.is_target or field.is_input: + batch = field.get(indices) + if not as_numpy and field.padder is not None: + batch = to_tensor(batch, field.dtype) + if field.is_target: + batch_y[field_name] = batch + if field.is_input: + batch_x[field_name] = batch + + return batch_x, batch_y + + +def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): + # 将数据送入到指定的query中. 
即如果需要pin_memory, 则 + if pin_memory: + torch.cuda.set_device(device_id) + + while True: + try: + r = in_queue.get() + except Exception: + if done_event.is_set(): + return + raise + if r is None: + break + if isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + continue + idx, batch, batch_indices = r + try: + if pin_memory: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()), batch_indices)) + else: + out_queue.put((idx, batch, batch_indices)) + + +def pin_memory_batch(batchs): + """ + + :param batchs: (batch_x, batch_y) + :return: (batch_x, batch_y) + """ + for batch_dict in batchs: + for field_name, batch in batch_dict.items(): + if isinstance(batch, torch.Tensor): + batch_dict[field_name] = batch.pin_memory() + return batchs + + +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if sys.platform == 'win32': + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True + + +class _DataLoaderIter(object): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + + def __init__(self, batcher): + self.batcher = batcher + self.dataset = batcher.dataset + self.sampler = batcher.sampler + self.as_numpy = batcher.as_numpy + self.batch_size = batcher.batch_size + self.num_workers = batcher.num_workers + self.pin_memory = batcher.pin_memory and torch.cuda.is_available() + self.timeout = batcher.timeout + self.done_event = threading.Event() + self.curidx = 0 + self.idx_list = self.sampler(self.dataset) + + # self.sample_iter一次返回一个index. 
可以通过其他方式替代 + + base_seed = torch.LongTensor(1).random_().item() + + if self.num_workers > 0: + # 每个worker建立一个index queue + self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)] + self.worker_queue_idx = 0 + # 存放获取到的batch + self.worker_result_queue = multiprocessing.SimpleQueue() + self.batches_outstanding = 0 + self.worker_pids_set = False + self.shutdown = False + self.send_idx = 0 + self.rcvd_idx = 0 + self.reorder_dict = {} + + # 这里会将batch的数据输送到self.worker_result_queue中,但是还没有送入到device中 + self.workers = [ + multiprocessing.Process( + target=_worker_loop, + args=(self.dataset, self.index_queues[i], + self.worker_result_queue, base_seed + i, i, self.as_numpy)) + for i in range(self.num_workers)] + + # self.data_queue取数据就行。如果有pin_memory的话,会把数据放到另一个queue + if self.pin_memory or self.timeout > 0: + self.data_queue = queue.Queue() + if self.pin_memory: + maybe_device_id = torch.cuda.current_device() + else: + # do not initialize cuda context if not necessary + maybe_device_id = None + self.worker_manager_thread = threading.Thread( + target=_worker_manager_loop, + args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, + maybe_device_id)) + self.worker_manager_thread.daemon = True + self.worker_manager_thread.start() + else: + self.data_queue = self.worker_result_queue + + # worker们开始工作 + for w in self.workers: + w.daemon = True # ensure that the worker exits on process exit + w.start() + + _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _set_SIGCHLD_handler() + self.worker_pids_set = True + + # prime the prefetch loop + for _ in range(2 * self.num_workers): + self._put_indices() + + def _get_batch(self): + if self.timeout > 0: + try: + return self.data_queue.get(timeout=self.timeout) + except queue.Empty: + raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) + else: + return self.data_queue.get() + + def __next__(self): + if self.num_workers == 0: # same-process loading + if self.curidx >= len(self.idx_list): + raise StopIteration + endidx = min(self.curidx + self.batch_size, len(self.idx_list)) + # 直接从数据集中采集数据即可 + indices = self.idx_list[self.curidx:endidx] + self.batcher.cur_batch_indices = indices + batch_x, batch_y = _get_batch_from_dataset(dataset=self.dataset, indices=indices, + as_numpy=self.as_numpy) + if self.pin_memory: + batch_x, batch_y = pin_memory_batch((batch_x, batch_y)) + self.curidx = endidx + return batch_x, batch_y + + # check if the next sample has already been generated + if self.rcvd_idx in self.reorder_dict: + batch = self.reorder_dict.pop(self.rcvd_idx) + return self._process_next_batch(batch) + + # 如果生成的数据为0了,则停止 + if self.batches_outstanding == 0: + self._shutdown_workers() + raise StopIteration + + while True: + assert (not self.shutdown and self.batches_outstanding > 0) + idx, batch, batch_indices = self._get_batch() + self.batches_outstanding -= 1 + if idx != self.rcvd_idx: + # store out-of-order samples + self.reorder_dict[idx] = batch + continue + self.batcher.cur_batch_indices = batch_indices + return self._process_next_batch(batch) + + def __iter__(self): + self.curidx = 0 + + return self + + def _put_indices(self): + # 向采集数据的index queue中放入index + assert self.batches_outstanding < 2 * self.num_workers + if self.curidx >= len(self.idx_list): + indices = None + else: + endidx = min(self.curidx + self.batch_size, len(self.idx_list)) + # 直接从数据集中采集数据即可 + indices = self.idx_list[self.curidx:endidx] + if indices is None: + return + 
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices)) + self.curidx = endidx + self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers + self.batches_outstanding += 1 + self.send_idx += 1 + + def _process_next_batch(self, batch): + # 只是提醒生成下一个batch indice数据 + self.rcvd_idx += 1 + self._put_indices() + if isinstance(batch, ExceptionWrapper): + raise batch.exc_type(batch.exc_msg) + return batch + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("_DataLoaderIter cannot be pickled") + + def _shutdown_workers(self): + try: + if not self.shutdown: + self.shutdown = True + self.done_event.set() + for q in self.index_queues: + q.put(None) + # if some workers are waiting to put, make place for them + try: + while not self.worker_result_queue.empty(): + self.worker_result_queue.get() + except (FileNotFoundError, ImportError): + # Many weird errors can happen here due to Python + # shutting down. These are more like obscure Python bugs. + # FileNotFoundError can happen when we rebuild the fd + # fetched from the queue but the socket is already closed + # from the worker side. + # ImportError can happen when the unpickler loads the + # resource from `get`. + pass + # done_event should be sufficient to exit worker_manager_thread, + # but be safe here and put another None + self.worker_result_queue.put(None) + finally: + # removes pids no matter what + if self.worker_pids_set: + _remove_worker_pids(id(self)) + self.worker_pids_set = False + + def __del__(self): + if self.num_workers > 0: + self._shutdown_workers() diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 96854e72..f3fcb3c8 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -408,7 +408,7 @@ class EngChar2DPadder(PadderBase): except: raise ValueError("Field:{} only has one dimension.".format(field_name)) try: - value = value[1] + value = value[0] except: raise ValueError("Field:{} only has two dimensions.".format(field_name)) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index fcafeb32..76a8562b 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature class Trainer(object): def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), - check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False, - callbacks=None): + check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False, + timeout=0, use_tqdm=True, use_cuda=False, callbacks=None): """ :param DataSet train_data: the training data :param torch.nn.modules.module model: a PyTorch model @@ -46,22 +46,27 @@ class Trainer(object): :param int print_every: step interval to print next training information. Default: -1(no print). :param int validate_every: step interval to do next validation. Default: -1(validate every epoch). :param DataSet dev_data: the validation data - :param bool use_cuda: whether to use CUDA in training. 
:param str save_path: file path to save models :param Optimizer optimizer: an optimizer object :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\ `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means - it will raise error if some field are not used. 检查的原理是通过使用很小的batch(默认两个sample)来检查代码是否能够 - 运行,但是这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况,;(2) - 模型中存在累加前向计算次数的,可能会多计算几次。建议将check_code_level设置为-1 + it will raise error if some field are not used. 检查的原理是通过使用很小的batch(默认两个sample)来检查代码是 + 否能够运行,但是这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个 + 固定值的情况;(2)模型中存在累加前向计算次数的,可能会多计算几次。以上情况建议将check_code_level设置为-1 :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets smaller, add "-" in front of the string. For example:: metric_key="-PPL" # language model gets better as perplexity gets smaller :param BaseSampler sampler: method used to generate batch data. + :param num_workers: int, 使用多少个进程来准备数据。默认为0, 即使用主线程生成数据。 特性处于实验阶段,谨慎使用。 + 如果DataSet较大,且每个batch的准备时间很短,使用多进程可能并不能提速。 + :param pin_memory: bool, 默认为False. 设置为True时,有可能可以节省tensor从cpu移动到gpu的阻塞时间。 + :param timeout: float, 大于0的数,只有在num_workers>0时才有用。超过该时间仍然没有获取到一个batch则报错,可以用于 + 检测是否出现了batch产生阻塞的情况。 :param bool use_tqdm: whether to use tqdm to show train progress. - + :param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 + 通过callback机制实现。 """ super(Trainer, self).__init__() @@ -117,6 +122,9 @@ class Trainer(object): self.validate_every = int(validate_every) if validate_every!=0 else -1 self.best_metric_indicator = None self.sampler = sampler + self.num_workers = num_workers + self.pin_memory = pin_memory + self.timeout = timeout self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) if isinstance(optimizer, torch.optim.Optimizer): @@ -237,7 +245,8 @@ class Trainer(object): len(self.train_data) % self.batch_size != 0)) * self.n_epochs with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: avg_loss = 0 - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, + num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout) for epoch in range(1, self.n_epochs+1): pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 2e0f383e..695efdfc 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -186,11 +186,12 @@ def _check_function_or_method(func): raise TypeError(f"{type(func)} is not a method or function.") -def _move_dict_value_to_device(*args, device: torch.device): +def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): """ move data to model's device, element in *args should be dict. This is a inplace change. 
:param device: torch.device + :param non_blocking: bool, 是否异步将数据转移到cpu, 需要tensor使用pin_memory() :param args: :return: """ @@ -201,7 +202,7 @@ def _move_dict_value_to_device(*args, device: torch.device): if isinstance(arg, dict): for key, value in arg.items(): if isinstance(value, torch.Tensor): - arg[key] = value.to(device) + arg[key] = value.to(device, non_blocking=non_blocking) else: raise TypeError("Only support `dict` type right now.") diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 7308ebf0..29a48559 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -8,7 +8,35 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import construct_dataset from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler +import time +def generate_fake_dataset(num_samples=1000): + """ + 产生的DataSet包含以下的field {'1':[], '2':[], '3': [], '4':[]} + :param num_samples: sample的数量 + :return: + """ + + max_len = 50 + min_len = 10 + num_features = 4 + + data_dict = {} + for i in range(num_features): + data = [] + lengths = np.random.randint(min_len, max_len, size=(num_samples)) + for length in lengths: + data.append(np.random.randint(100, size=length)) + data_dict[str(i)] = data + + dataset = DataSet(data_dict) + + for i in range(num_features): + if np.random.randint(2) == 0: + dataset.set_input(str(i)) + else: + dataset.set_target(str(i)) + return dataset class TestCase1(unittest.TestCase): def test_simple(self): @@ -98,3 +126,47 @@ class TestCase1(unittest.TestCase): iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) for x, y in iter: print(x, y) + + def test_sequential_batch(self): + batch_size = 32 + pause_seconds = 0.01 + num_samples = 1000 + dataset = generate_fake_dataset(num_samples) + + batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) + for batch_x, batch_y in batch: + time.sleep(pause_seconds) + + def test_multi_workers_batch(self): + batch_size = 32 + pause_seconds = 0.01 + num_samples = 1000 + dataset = generate_fake_dataset(num_samples) + + num_workers = 1 + batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers) + for batch_x, batch_y in batch: + time.sleep(pause_seconds) + + num_workers = 2 + batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers) + end1 = time.time() + for batch_x, batch_y in batch: + time.sleep(pause_seconds) + + def test_pin_memory(self): + batch_size = 32 + pause_seconds = 0.01 + num_samples = 1000 + dataset = generate_fake_dataset(num_samples) + + batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True) + for batch_x, batch_y in batch: + time.sleep(pause_seconds) + + num_workers = 2 + batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers, + pin_memory=True) + for batch_x, batch_y in batch: + time.sleep(pause_seconds) +
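For reference, a minimal usage sketch of the multi-process batching introduced by this patch, assuming only the Batch and Trainer signatures shown above (num_workers, pin_memory, timeout). The toy DataSet construction follows generate_fake_dataset from test/core/test_batch.py; the field names 'words' and 'label' are illustrative only.

    import numpy as np
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import SequentialSampler

    # Build a small toy DataSet: one variable-length input field, one target field.
    num_samples = 1000
    data = {
        'words': [np.random.randint(100, size=np.random.randint(10, 50)) for _ in range(num_samples)],
        'label': [int(np.random.randint(5)) for _ in range(num_samples)],
    }
    dataset = DataSet(data)
    dataset.set_input('words')
    dataset.set_target('label')

    # num_workers=2: batches are padded in two worker processes instead of the main process.
    # pin_memory=True: returned tensors are pinned, so a later .to(device, non_blocking=True)
    #   copy can overlap with computation (silently ignored if CUDA is unavailable).
    # timeout=10: raise if no batch arrives within 10 seconds (only meaningful when num_workers > 0).
    batch = Batch(dataset, batch_size=32, sampler=SequentialSampler(),
                  num_workers=2, pin_memory=True, timeout=10)
    for batch_x, batch_y in batch:
        pass  # batch_x['words'] and batch_y['label'] arrive as torch tensors ('words' padded to the batch max length)

    # The same options are forwarded by Trainer (other arguments omitted):
    # trainer = Trainer(train_data=dataset, model=model, ...,
    #                   num_workers=2, pin_memory=True, timeout=10)

When num_workers > 0, the iteration should sit under an if __name__ == '__main__': guard on platforms where worker processes are spawned rather than forked.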