BucketedBatchSampler的batch_id_in_epoch实现

This commit is contained in:
x54-729 2022-04-16 15:46:57 +00:00
parent 77f6b63ba6
commit cb01a661f1

View File

@ -411,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.old_num_replicas = states['num_replicas']
def set_epoch(self, epoch):
self.epoch = epoch
self.epoch = epoch
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size