diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index be43bc74..171a784b 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -19,7 +19,7 @@ from abc import abstractmethod class ReproducibleBatchSampler: def __init__(self, **kwargs): - pass + self.num_replicas = 1 @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): @@ -53,14 +53,6 @@ class ReproducibleBatchSampler: def batch_idx_in_epoch(self): raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") - @property - def num_replicas(self): - return self._num_replicas - - @num_replicas.setter - def num_replicas(self, value): - self._num_replicas = value - class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; @@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(batches[-1])==0: batches.pop(-1) - assert len(list(chain(*batches))) == self.num_left_samples + assert sum(map(len, batches)) == self.num_left_samples if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: batches = batches[:-1] diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index c3facbb9..c8425dc7 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -20,6 +20,8 @@ class ReproducibleSampler: 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ + def __init__(self, **kwargs): + self.num_replicas = 1 def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.") @@ -47,14 +49,6 @@ class ReproducibleSampler: def set_epoch(self, epoch): pass - @property - def num_repliacs(self): - return self._num_replicas - - @num_repliacs.setter - def num_repliacs(self, value): - self._num_replicas = value - class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):