diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 352e6127..de6303ad 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -177,34 +177,36 @@ class DummyCallback(Callback): def before_train(self, *arg): print(arg) - def after_epoch(self): - print("after epoch!!!") - return 12 + def after_epoch(self, cur_epoch, n_epoch, optimizer): + print(cur_epoch, n_epoch, optimizer) class EchoCallback(Callback): def before_train(self): print("before_train") - def before_epoch(self): + def before_epoch(self, cur_epoch, total_epoch): print("before_epoch") - def before_batch(self): + def before_batch(self, batch_x, batch_y, indices): print("before_batch") + print("batch_x:", batch_x) + print("batch_y:", batch_y) + print("indices: ", indices) - def before_loss(self): + def before_loss(self, batch_y, predict_y): print("before_loss") - def before_backward(self): + def before_backward(self, loss, model): print("before_backward") def after_batch(self): print("after_batch") - def after_epoch(self): + def after_epoch(self, cur_epoch, n_epoch, optimizer): print("after_epoch") - def after_train(self): + def after_train(self, model): print("after_train") class GradientClipCallback(Callback):