test loss

This commit is contained in:
yh 2018-12-04 16:22:41 +08:00
parent abe5ec7261
commit 62c63f159a

View File

@ -300,3 +300,22 @@ class TestLoss_v2(unittest.TestCase):
b = torch.tensor([1, 0, 4])
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))
class TestLosserError(unittest.TestCase):
def test_losser1(self):
# (1) only input, targets passed
pred_dict = {"pred": torch.zeros(4, 3)}
target_dict = {'target': torch.zeros(4).long()}
los = loss.CrossEntropyLoss()
print(los(pred_dict=pred_dict, target_dict=target_dict))
#
def test_AccuracyMetric2(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3, 4)}
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()
print(los(pred_dict=pred_dict, target_dict=target_dict))