mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
添加了 reproducible batch sampler 的测试
This commit is contained in:
parent
7d5ce620f4
commit
1528107480
@ -8,7 +8,6 @@ import math
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, Union, List
|
from typing import Dict, Union, List
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改;
|
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改;
|
||||||
from tests.helpers.datasets.normal_data import NormalIterator
|
from tests.helpers.datasets.normal_data import NormalSampler
|
||||||
|
|
||||||
|
|
||||||
class Test_WrapDataLoader:
|
class Test_WrapDataLoader:
|
||||||
@ -9,7 +9,7 @@ class Test_WrapDataLoader:
|
|||||||
def test_normal_generator(self):
|
def test_normal_generator(self):
|
||||||
all_sanity_batches = [4, 20, 100]
|
all_sanity_batches = [4, 20, 100]
|
||||||
for sanity_batches in all_sanity_batches:
|
for sanity_batches in all_sanity_batches:
|
||||||
data = NormalIterator(num_of_data=1000)
|
data = NormalSampler(num_of_data=1000)
|
||||||
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
|
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
|
||||||
dataloader = iter(wrapper(dataloader=data))
|
dataloader = iter(wrapper(dataloader=data))
|
||||||
mark = 0
|
mark = 0
|
||||||
|
@ -1,161 +1,131 @@
|
|||||||
from array import array
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from array import array
|
||||||
|
|
||||||
|
from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
|
||||||
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
|
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
|
||||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
|
|
||||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
|
||||||
|
|
||||||
#
|
|
||||||
# class TestReproducibleBatchSampler:
|
class TestReproducibleBatchSampler:
|
||||||
# # TODO 拆分测试,在这里只测试一个东西
|
def test_1(self):
|
||||||
# def test_torch_dataloader_1(self):
|
sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响;
|
||||||
# import torch
|
|
||||||
# from torch.utils.data import DataLoader
|
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
|
||||||
# # no shuffle
|
|
||||||
# before_batch_size = 7
|
forward_steps = 3
|
||||||
# dataset = TorchNormalDataset(num_of_data=100)
|
iterator = iter(reproduce_batch_sampler)
|
||||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
i = 0
|
||||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
while i < forward_steps:
|
||||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
next(iterator)
|
||||||
#
|
i += 1
|
||||||
# forward_steps = 3
|
|
||||||
# iter_dataloader = iter(dataloader)
|
# 保存状态;
|
||||||
# for _ in range(forward_steps):
|
state = reproduce_batch_sampler.state_dict()
|
||||||
# next(iter_dataloader)
|
|
||||||
#
|
assert state == {"index_list": array("I", list(range(100))),
|
||||||
# # 1. 保存状态
|
"num_consumed_samples": forward_steps * 4,
|
||||||
# _get_re_batchsampler = dataloader.batch_sampler
|
"sampler_type": "ReproduceBatchSampler"}
|
||||||
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
|
||||||
# state = _get_re_batchsampler.state_dict()
|
# 重新生成一个 batchsampler 然后加载状态;
|
||||||
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
|
sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响;
|
||||||
# "sampler_type": "ReproduceBatchSampler"}
|
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
|
||||||
#
|
reproduce_batch_sampler.load_state_dict(state)
|
||||||
# # 2. 断点重训,重新生成一个 dataloader;
|
|
||||||
# # 不改变 batch_size;
|
real_res = []
|
||||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
supposed_res = (list(range(12, 16)), list(range(16, 20)))
|
||||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
forward_steps = 2
|
||||||
# re_batchsampler.load_state_dict(state)
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
for _ in range(forward_steps):
|
||||||
#
|
real_res.append(next(iter_dataloader))
|
||||||
# real_res = []
|
|
||||||
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
|
for i in range(forward_steps):
|
||||||
# forward_steps = 2
|
assert supposed_res[i] == real_res[i]
|
||||||
# iter_dataloader = iter(dataloader)
|
|
||||||
# for _ in range(forward_steps):
|
# 改变 batchsize;
|
||||||
# real_res.append(next(iter_dataloader))
|
sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响;
|
||||||
#
|
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
|
||||||
# for i in range(forward_steps):
|
reproduce_batch_sampler.load_state_dict(state)
|
||||||
# assert all(real_res[i] == supposed_res[i])
|
|
||||||
#
|
real_res = []
|
||||||
# # 改变 batch_size;
|
supposed_res = (list(range(12, 19)), list(range(19, 26)))
|
||||||
# after_batch_size = 3
|
forward_steps = 2
|
||||||
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
for _ in range(forward_steps):
|
||||||
# re_batchsampler.load_state_dict(state)
|
real_res.append(next(iter_dataloader))
|
||||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
|
||||||
#
|
for i in range(forward_steps):
|
||||||
# real_res = []
|
assert supposed_res[i] == real_res[i]
|
||||||
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
|
|
||||||
# forward_steps = 2
|
# 断点重训的第二轮是否是一个完整的 dataloader;
|
||||||
# iter_dataloader = iter(dataloader)
|
# 先把断点重训所在的那一个 epoch 跑完;
|
||||||
# for _ in range(forward_steps):
|
begin_idx = 26
|
||||||
# real_res.append(next(iter_dataloader))
|
while True:
|
||||||
#
|
try:
|
||||||
# for i in range(forward_steps):
|
data = next(iter_dataloader)
|
||||||
# assert all(real_res[i] == supposed_res[i])
|
_batch_size = len(data)
|
||||||
#
|
assert data == list(range(begin_idx, begin_idx + _batch_size))
|
||||||
# # 断点重训的第二轮是否是一个完整的 dataloader;
|
begin_idx += _batch_size
|
||||||
# # 先把断点重训所在的那一个 epoch 跑完;
|
except StopIteration:
|
||||||
# begin_idx = 27
|
break
|
||||||
# while True:
|
|
||||||
# try:
|
# 开始新的一轮;
|
||||||
# data = next(iter_dataloader)
|
begin_idx = 0
|
||||||
# _batch_size = len(data)
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
|
while True:
|
||||||
# begin_idx += _batch_size
|
try:
|
||||||
# except StopIteration:
|
data = next(iter_dataloader)
|
||||||
# break
|
_batch_size = len(data)
|
||||||
#
|
assert data == list(range(begin_idx, begin_idx + _batch_size))
|
||||||
# # 开始新的一轮;
|
begin_idx += _batch_size
|
||||||
# begin_idx = 0
|
except StopIteration:
|
||||||
# iter_dataloader = iter(dataloader)
|
break
|
||||||
# while True:
|
|
||||||
# try:
|
def test_2(self):
|
||||||
# data = next(iter_dataloader)
|
|
||||||
# _batch_size = len(data)
|
# 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的;
|
||||||
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
|
before_batch_size = 7
|
||||||
# begin_idx += _batch_size
|
sampler = NormalSampler(num_of_data=100)
|
||||||
# except StopIteration:
|
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
||||||
# break
|
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
|
||||||
#
|
|
||||||
# def test_torch_dataloader_2(self):
|
# 将一轮的所有数据保存下来,看是否恢复的是正确的;
|
||||||
# # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的;
|
all_supposed_data = []
|
||||||
# from torch.utils.data import DataLoader
|
forward_steps = 3
|
||||||
# # no shuffle
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
# before_batch_size = 7
|
for _ in range(forward_steps):
|
||||||
# dataset = TorchNormalDataset(num_of_data=100)
|
all_supposed_data.extend(next(iter_dataloader))
|
||||||
# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
|
||||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
# 1. 保存状态
|
||||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
state = reproduce_batch_sampler.state_dict()
|
||||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
|
||||||
#
|
# 2. 断点重训,重新生成一个 dataloader;
|
||||||
# # 将一轮的所有数据保存下来,看是否恢复的是正确的;
|
# 不改变 batch_size;
|
||||||
# all_supposed_data = []
|
sampler = NormalSampler(num_of_data=100, shuffle=True)
|
||||||
# forward_steps = 3
|
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
|
||||||
# iter_dataloader = iter(dataloader)
|
reproduce_batch_sampler.load_state_dict(state)
|
||||||
# for _ in range(forward_steps):
|
|
||||||
# all_supposed_data.extend(next(iter_dataloader).tolist())
|
# 先把这一轮的数据过完;
|
||||||
#
|
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
|
||||||
# # 1. 保存状态
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
# _get_re_batchsampler = dataloader.batch_sampler
|
while True:
|
||||||
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
try:
|
||||||
# state = _get_re_batchsampler.state_dict()
|
all_supposed_data.extend(next(iter_dataloader))
|
||||||
#
|
except StopIteration:
|
||||||
# # 2. 断点重训,重新生成一个 dataloader;
|
break
|
||||||
# # 不改变 batch_size;
|
assert all_supposed_data == list(pre_index_list)
|
||||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
|
||||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
# 重新开启新的一轮;
|
||||||
# re_batchsampler.load_state_dict(state)
|
for _ in range(3):
|
||||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
iter_dataloader = iter(reproduce_batch_sampler)
|
||||||
#
|
res = []
|
||||||
# # 先把这一轮的数据过完;
|
while True:
|
||||||
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
|
try:
|
||||||
# while True:
|
res.extend(next(iter_dataloader))
|
||||||
# try:
|
except StopIteration:
|
||||||
# all_supposed_data.extend(next(iter_dataloader).tolist())
|
break
|
||||||
# except StopIteration:
|
assert res != all_supposed_data
|
||||||
# break
|
|
||||||
# assert all_supposed_data == list(pre_index_list)
|
|
||||||
#
|
|
||||||
# # 重新开启新的一轮;
|
|
||||||
# for _ in range(3):
|
|
||||||
# iter_dataloader = iter(dataloader)
|
|
||||||
# res = []
|
|
||||||
# while True:
|
|
||||||
# try:
|
|
||||||
# res.append(next(iter_dataloader))
|
|
||||||
# except StopIteration:
|
|
||||||
# break
|
|
||||||
#
|
|
||||||
# def test_3(self):
|
|
||||||
# import torch
|
|
||||||
# from torch.utils.data import DataLoader
|
|
||||||
# before_batch_size = 7
|
|
||||||
# dataset = TorchNormalDataset(num_of_data=100)
|
|
||||||
# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
|
||||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
|
||||||
#
|
|
||||||
# for idx, data in enumerate(dataloader):
|
|
||||||
# if idx > 3:
|
|
||||||
# break
|
|
||||||
#
|
|
||||||
# iterator = iter(dataloader)
|
|
||||||
# for each in iterator:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetWithVaryLength:
|
class DatasetWithVaryLength:
|
||||||
|
141
tests/core/samplers/test_reproducible_batch_sampler_torch.py
Normal file
141
tests/core/samplers/test_reproducible_batch_sampler_torch.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
from array import array
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastNLP.core.samplers import ReproduceBatchSampler
|
||||||
|
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
|
||||||
|
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.torch
|
||||||
|
class TestReproducibleBatchSamplerTorch:
|
||||||
|
def test_torch_dataloader_1(self):
|
||||||
|
# no shuffle
|
||||||
|
before_batch_size = 7
|
||||||
|
dataset = TorchNormalDataset(num_of_data=100)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||||
|
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||||
|
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||||
|
|
||||||
|
forward_steps = 3
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
for _ in range(forward_steps):
|
||||||
|
next(iter_dataloader)
|
||||||
|
|
||||||
|
# 1. 保存状态
|
||||||
|
_get_re_batchsampler = dataloader.batch_sampler
|
||||||
|
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
||||||
|
state = _get_re_batchsampler.state_dict()
|
||||||
|
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
|
||||||
|
"sampler_type": "ReproduceBatchSampler"}
|
||||||
|
|
||||||
|
# 2. 断点重训,重新生成一个 dataloader;
|
||||||
|
# 不改变 batch_size;
|
||||||
|
dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||||
|
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||||
|
re_batchsampler.load_state_dict(state)
|
||||||
|
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||||
|
|
||||||
|
real_res = []
|
||||||
|
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
|
||||||
|
forward_steps = 2
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
for _ in range(forward_steps):
|
||||||
|
real_res.append(next(iter_dataloader))
|
||||||
|
|
||||||
|
for i in range(forward_steps):
|
||||||
|
assert all(real_res[i] == supposed_res[i])
|
||||||
|
|
||||||
|
# 改变 batch_size;
|
||||||
|
after_batch_size = 3
|
||||||
|
dataloader = DataLoader(dataset, batch_size=after_batch_size)
|
||||||
|
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||||
|
re_batchsampler.load_state_dict(state)
|
||||||
|
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||||
|
|
||||||
|
real_res = []
|
||||||
|
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
|
||||||
|
forward_steps = 2
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
for _ in range(forward_steps):
|
||||||
|
real_res.append(next(iter_dataloader))
|
||||||
|
|
||||||
|
for i in range(forward_steps):
|
||||||
|
assert all(real_res[i] == supposed_res[i])
|
||||||
|
|
||||||
|
# 断点重训的第二轮是否是一个完整的 dataloader;
|
||||||
|
# 先把断点重训所在的那一个 epoch 跑完;
|
||||||
|
begin_idx = 27
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = next(iter_dataloader)
|
||||||
|
_batch_size = len(data)
|
||||||
|
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
|
||||||
|
begin_idx += _batch_size
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 开始新的一轮;
|
||||||
|
begin_idx = 0
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = next(iter_dataloader)
|
||||||
|
_batch_size = len(data)
|
||||||
|
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
|
||||||
|
begin_idx += _batch_size
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_torch_dataloader_2(self):
|
||||||
|
# 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的;
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
before_batch_size = 7
|
||||||
|
dataset = TorchNormalDataset(num_of_data=100)
|
||||||
|
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
||||||
|
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||||
|
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||||
|
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||||
|
|
||||||
|
# 将一轮的所有数据保存下来,看是否恢复的是正确的;
|
||||||
|
all_supposed_data = []
|
||||||
|
forward_steps = 3
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
for _ in range(forward_steps):
|
||||||
|
all_supposed_data.extend(next(iter_dataloader).tolist())
|
||||||
|
|
||||||
|
# 1. 保存状态
|
||||||
|
_get_re_batchsampler = dataloader.batch_sampler
|
||||||
|
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
||||||
|
state = _get_re_batchsampler.state_dict()
|
||||||
|
|
||||||
|
# 2. 断点重训,重新生成一个 dataloader;
|
||||||
|
# 不改变 batch_size;
|
||||||
|
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||||
|
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||||
|
re_batchsampler.load_state_dict(state)
|
||||||
|
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||||
|
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
# 先把这一轮的数据过完;
|
||||||
|
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
all_supposed_data.extend(next(iter_dataloader).tolist())
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
assert all_supposed_data == list(pre_index_list)
|
||||||
|
|
||||||
|
# 重新开启新的一轮;
|
||||||
|
for _ in range(3):
|
||||||
|
iter_dataloader = iter(dataloader)
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
res.extend(next(iter_dataloader).tolist())
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
assert res != all_supposed_data
|
||||||
|
|
@ -1,13 +1,25 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class NormalIterator:
|
class NormalSampler:
|
||||||
def __init__(self, num_of_data=1000):
|
def __init__(self, num_of_data=1000, shuffle=False):
|
||||||
self._num_of_data = num_of_data
|
self._num_of_data = num_of_data
|
||||||
self._data = list(range(num_of_data))
|
self._data = list(range(num_of_data))
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(self._data)
|
||||||
|
self.shuffle = shuffle
|
||||||
self._index = 0
|
self._index = 0
|
||||||
|
self.need_reinitialize = False
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
if self.need_reinitialize:
|
||||||
|
self._index = 0
|
||||||
|
if self.shuffle:
|
||||||
|
random.shuffle(self._data)
|
||||||
|
else:
|
||||||
|
self.need_reinitialize = True
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
@ -15,12 +27,45 @@ class NormalIterator:
|
|||||||
raise StopIteration
|
raise StopIteration
|
||||||
_data = self._data[self._index]
|
_data = self._data[self._index]
|
||||||
self._index += 1
|
self._index += 1
|
||||||
return self._data
|
return _data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._num_of_data
|
return self._num_of_data
|
||||||
|
|
||||||
|
|
||||||
|
class NormalBatchSampler:
|
||||||
|
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
|
||||||
|
# Since collections.abc.Iterable does not check for `__getitem__`, which
|
||||||
|
# is one way for an object to be an iterable, we don't do an `isinstance`
|
||||||
|
# check here.
|
||||||
|
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
|
||||||
|
batch_size <= 0:
|
||||||
|
raise ValueError("batch_size should be a positive integer value, "
|
||||||
|
"but got batch_size={}".format(batch_size))
|
||||||
|
if not isinstance(drop_last, bool):
|
||||||
|
raise ValueError("drop_last should be a boolean value, but got "
|
||||||
|
"drop_last={}".format(drop_last))
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.drop_last = drop_last
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
batch = []
|
||||||
|
for idx in self.sampler:
|
||||||
|
batch.append(idx)
|
||||||
|
if len(batch) == self.batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0 and not self.drop_last:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.drop_last:
|
||||||
|
return len(self.sampler) // self.batch_size
|
||||||
|
else:
|
||||||
|
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
||||||
|
|
||||||
|
|
||||||
class RandomDataset:
|
class RandomDataset:
|
||||||
def __init__(self, num_data=10):
|
def __init__(self, num_data=10):
|
||||||
self.data = np.random.rand(num_data)
|
self.data = np.random.rand(num_data)
|
||||||
@ -30,3 +75,6 @@ class RandomDataset:
|
|||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.data[item]
|
return self.data[item]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user