mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
BucketedBatchSampler的batch_id_in_epoch实现
This commit is contained in:
parent
77f6b63ba6
commit
cb01a661f1
@ -411,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
self.old_num_replicas = states['num_replicas']
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
self.epoch = epoch
|
||||
|
||||
@property
|
||||
def batch_idx_in_epoch(self):
|
||||
if self.drop_last:
|
||||
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
|
||||
else:
|
||||
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
|
||||
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
|
Loading…
Reference in New Issue
Block a user