mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
a8597602d0
@ -16,19 +16,12 @@ class PaddleNormalDataset(Dataset):
|
||||
|
||||
|
||||
class PaddleRandomDataset(Dataset):
|
||||
def __init__(self, num_of_data=1000, features=64, labels=10):
|
||||
self.num_of_data = num_of_data
|
||||
self.x = [
|
||||
paddle.rand((features,))
|
||||
for i in range(num_of_data)
|
||||
]
|
||||
self.y = [
|
||||
paddle.rand((labels,))
|
||||
for i in range(num_of_data)
|
||||
]
|
||||
def __init__(self, num_samples, num_features):
|
||||
self.x = paddle.randn((num_samples, num_features))
|
||||
self.y = self.x.argmax(axis=-1)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_of_data
|
||||
return len(self.x)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {"x": self.x[item], "y": self.y[item]}
|
||||
|
Loading…
Reference in New Issue
Block a user