mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
微调 reproducible sampler 的初始化
This commit is contained in:
parent
aa2f678507
commit
3dbb3677f0
@ -19,7 +19,7 @@ from abc import abstractmethod
|
||||
|
||||
class ReproducibleBatchSampler:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
self.num_replicas = 1
|
||||
|
||||
@abstractmethod
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
@ -53,14 +53,6 @@ class ReproducibleBatchSampler:
|
||||
def batch_idx_in_epoch(self):
|
||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
return self._num_replicas
|
||||
|
||||
@num_replicas.setter
|
||||
def num_replicas(self, value):
|
||||
self._num_replicas = value
|
||||
|
||||
|
||||
class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
|
||||
@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
if len(batches[-1])==0:
|
||||
batches.pop(-1)
|
||||
|
||||
assert len(list(chain(*batches))) == self.num_left_samples
|
||||
assert sum(map(len, batches)) == self.num_left_samples
|
||||
|
||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
|
||||
batches = batches[:-1]
|
||||
|
@ -20,6 +20,8 @@ class ReproducibleSampler:
|
||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。
|
||||
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
self.num_replicas = 1
|
||||
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")
|
||||
@ -47,14 +49,6 @@ class ReproducibleSampler:
|
||||
def set_epoch(self, epoch):
|
||||
pass
|
||||
|
||||
@property
|
||||
def num_repliacs(self):
|
||||
return self._num_replicas
|
||||
|
||||
@num_repliacs.setter
|
||||
def num_repliacs(self, value):
|
||||
self._num_replicas = value
|
||||
|
||||
|
||||
class RandomSampler(ReproducibleSampler):
|
||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user