From cb01a661f1662a44995bf9988f58df17c32d1db6 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 15:46:57 +0000 Subject: [PATCH] =?UTF-8?q?BucketedBatchSampler=E7=9A=84batch=5Fid=5Fin=5F?= =?UTF-8?q?epoch=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 171a784b..e8acc645 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -411,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): self.old_num_replicas = states['num_replicas'] def set_epoch(self, epoch): - self.epoch = epoch \ No newline at end of file + self.epoch = epoch + + @property + def batch_idx_in_epoch(self): + if self.drop_last: + return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ + (len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size \ No newline at end of file