微调 reproducible sampler 的初始化

This commit is contained in:
x54-729 2022-04-16 08:39:07 +00:00
parent aa2f678507
commit 3dbb3677f0
2 changed files with 4 additions and 18 deletions

View File

@ -19,7 +19,7 @@ from abc import abstractmethod
class ReproducibleBatchSampler:
def __init__(self, **kwargs):
pass
self.num_replicas = 1
@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
@ -53,14 +53,6 @@ class ReproducibleBatchSampler:
def batch_idx_in_epoch(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
@property
def num_replicas(self):
return self._num_replicas
@num_replicas.setter
def num_replicas(self, value):
self._num_replicas = value
class RandomBatchSampler(ReproducibleBatchSampler):
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if len(batches[-1])==0:
batches.pop(-1)
assert len(list(chain(*batches))) == self.num_left_samples
assert sum(map(len, batches)) == self.num_left_samples
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

View File

@ -20,6 +20,8 @@ class ReproducibleSampler:
或者 batch_sampler注意所有在 init 中初始化的变量都不能含有 _ 下横线作为开头所有不在 init 中设置的变量都必须以下横线开头
"""
def __init__(self, **kwargs):
self.num_replicas = 1
def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")
@ -47,14 +49,6 @@ class ReproducibleSampler:
def set_epoch(self, epoch):
pass
@property
def num_repliacs(self):
return self._num_replicas
@num_repliacs.setter
def num_repliacs(self, value):
self._num_replicas = value
class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):