Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
YWMditto 2022-04-09 15:28:19 +08:00
commit a8597602d0

View File

@ -16,19 +16,12 @@ class PaddleNormalDataset(Dataset):
class PaddleRandomDataset(Dataset):
def __init__(self, num_of_data=1000, features=64, labels=10):
self.num_of_data = num_of_data
self.x = [
paddle.rand((features,))
for i in range(num_of_data)
]
self.y = [
paddle.rand((labels,))
for i in range(num_of_data)
]
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)
def __len__(self):
return self.num_of_data
return len(self.x)
def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}