Change how GradientClipCallback stores `parameters` to avoid an error on torch 1.5

yh_cc 2020-04-29 00:13:13 +08:00
parent fbd2fd4ead
commit 62fe53b147


@@ -464,7 +464,7 @@ class GradientClipCallback(Callback):
             self.clip_fun = nn.utils.clip_grad_value_
         else:
             raise ValueError("Only supports `norm` or `value` right now.")
-        self.parameters = parameters
+        self.parameters = list(parameters)
         self.clip_value = clip_value
 
     def on_backward_end(self):
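
Why the one-line change likely matters (a sketch of the rationale, not text from the commit): `Module.parameters()` returns a generator, which can be iterated only once. If the callback stores the generator as-is, the first call to `on_backward_end` consumes it and every later call clips over an empty sequence; on torch 1.5, `clip_grad_norm_` computes the total norm by stacking per-parameter norms, so an empty parameter list raises a RuntimeError from `torch.stack`. Materializing the parameters with `list(...)` once at construction keeps them re-iterable on every backward pass. A minimal, self-contained illustration (the `nn.Linear` model and the loop are hypothetical, not code from the repository):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# A generator of parameters is exhausted after a single pass:
gen = model.parameters()
assert len(list(gen)) == 2   # weight and bias
assert len(list(gen)) == 0   # a second iteration yields nothing

# The fix: materialize once, reuse on every step.
params = list(model.parameters())
for step in range(2):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    # With a stored generator this call would see no parameters on the
    # second step and, on torch 1.5, fail inside torch.stack.
    nn.utils.clip_grad_norm_(params, max_norm=1.0)
    model.zero_grad()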