mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
test loss
This commit is contained in:
parent
abe5ec7261
commit
62c63f159a
@ -300,3 +300,22 @@ class TestLoss_v2(unittest.TestCase):
|
||||
b = torch.tensor([1, 0, 4])
|
||||
ans = l1({"my_predict": a}, {"my_truth": b})
|
||||
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))
|
||||
|
||||
class TestLosserError(unittest.TestCase):
|
||||
def test_losser1(self):
|
||||
# (1) only input, targets passed
|
||||
pred_dict = {"pred": torch.zeros(4, 3)}
|
||||
target_dict = {'target': torch.zeros(4).long()}
|
||||
los = loss.CrossEntropyLoss()
|
||||
|
||||
print(los(pred_dict=pred_dict, target_dict=target_dict))
|
||||
|
||||
#
|
||||
def test_AccuracyMetric2(self):
|
||||
# (2) with corrupted size
|
||||
pred_dict = {"pred": torch.zeros(16, 3, 4)}
|
||||
target_dict = {'target': torch.zeros(16, 3).long()}
|
||||
los = loss.CrossEntropyLoss()
|
||||
|
||||
print(los(pred_dict=pred_dict, target_dict=target_dict))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user