mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
删除 paddle_driver/test_utils.py 中 get_device_from_visible 的测试例
This commit is contained in:
parent
6fb694db91
commit
007678b6d9
@ -370,29 +370,11 @@ class TestDataSetMethods:
|
||||
assert os.path.exists("1.csv") == True
|
||||
os.remove("1.csv")
|
||||
|
||||
def test_add_collate_fn(self):
|
||||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
|
||||
def collate_fn(item):
|
||||
return item
|
||||
|
||||
ds.add_collate_fn(collate_fn)
|
||||
|
||||
def test_get_collator(self):
|
||||
from typing import Callable
|
||||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
collate_fn = ds.get_collator()
|
||||
assert isinstance(collate_fn, Callable) == True
|
||||
|
||||
def test_add_seq_len(self):
|
||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
|
||||
ds.add_seq_len('x')
|
||||
print(ds)
|
||||
|
||||
def test_set_target(self):
|
||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
|
||||
ds.set_target('x')
|
||||
|
||||
|
||||
class TestFieldArrayInit:
|
||||
"""
|
||||
|
@ -1,8 +1,6 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.drivers.paddle_driver.utils import (
|
||||
get_device_from_visible,
|
||||
replace_batch_sampler,
|
||||
replace_sampler,
|
||||
)
|
||||
@ -14,24 +12,6 @@ if _NEED_IMPORT_PADDLE:
|
||||
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_visible_devices, cuda_visible_devices, device, output_type, correct"),
|
||||
(
|
||||
("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"),
|
||||
("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"),
|
||||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1),
|
||||
("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"),
|
||||
("3,4,5,6", "3,5", 0, int, 0),
|
||||
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"),
|
||||
)
|
||||
)
|
||||
@pytest.mark.paddle
|
||||
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
|
||||
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices
|
||||
res = get_device_from_visible(device, output_type)
|
||||
assert res == correct
|
||||
|
||||
@pytest.mark.paddle
|
||||
def test_replace_batch_sampler():
|
||||
dataset = PaddleNormalDataset(10)
|
||||
|
Loading…
Reference in New Issue
Block a user