mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
initialize_paddle_driver的测试例
This commit is contained in:
parent
e3d565b639
commit
193c04c9e2
@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
|
||||
if driver not in {"paddle", "fleet"}:
|
||||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")
|
||||
|
||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
|
||||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
|
||||
# 优先级 user > cuda
|
||||
# 判断单机情况 device 的合法性
|
||||
# 分布式情况下通过 world_device 判断
|
||||
if user_visible_devices != "":
|
||||
_could_use_device_num = len(user_visible_devices.split(","))
|
||||
elif cuda_visible_devices is not None:
|
||||
_could_use_device_num = len(cuda_visible_devices.split(","))
|
||||
else:
|
||||
_could_use_device_num = paddle.device.cuda.device_count()
|
||||
if user_visible_devices is None:
|
||||
raise RuntimeError("This situation cannot happen, please report a bug to us.")
|
||||
_could_use_device_num = len(user_visible_devices.split(","))
|
||||
if isinstance(device, int):
|
||||
if device < 0 and device != -1:
|
||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
|
||||
# if device >= _could_use_device_num:
|
||||
# raise ValueError("The gpu device that parameter `device` specifies is not existed.")
|
||||
device = f"gpu:{device}"
|
||||
if device >= _could_use_device_num:
|
||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
|
||||
if device != -1:
|
||||
device = f"gpu:{device}"
|
||||
else:
|
||||
device = list(range(_could_use_device_num))
|
||||
elif isinstance(device, Sequence) and not isinstance(device, str):
|
||||
device = list(set(device))
|
||||
for each in device:
|
||||
@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
|
||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.")
|
||||
elif each < 0:
|
||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.")
|
||||
elif each >= _could_use_device_num:
|
||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
|
||||
" the available gpu number.")
|
||||
if len(device) == 1:
|
||||
# 传入了 [1] 这样的,视为单卡。
|
||||
device = device[0]
|
||||
|
@ -1,83 +1,103 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from fastNLP.envs.set_backend import set_env
|
||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
|
||||
|
||||
set_env_on_import_paddle()
|
||||
set_env("paddle")
|
||||
import paddle
|
||||
os.environ["FASTNLP_BACKEND"] = "paddle"
|
||||
|
||||
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver
|
||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
|
||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
|
||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
|
||||
from fastNLP.envs import get_gpu_count
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
|
||||
import paddle
|
||||
|
||||
def test_incorrect_driver():
|
||||
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_paddle_driver("torch")
|
||||
driver = initialize_paddle_driver("torch", 0, model)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
|
||||
["cpu", "gpu:0", 0, [1]]
|
||||
)
|
||||
def test_get_single_device(device):
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
["paddle"]
|
||||
)
|
||||
def test_get_single_device(driver, device):
|
||||
"""
|
||||
测试正常情况下初始化PaddleSingleDriver的情况
|
||||
"""
|
||||
|
||||
model = PaddleNormalModel_Classification(2, 100)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
driver = initialize_paddle_driver(driver, device, model)
|
||||
assert isinstance(driver, PaddleSingleDriver)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
|
||||
[0, 1]
|
||||
)
|
||||
def test_get_single_device_with_visiblde_devices(device):
|
||||
"""
|
||||
测试 CUDA_VISIBLE_DEVICES 启动时初始化PaddleSingleDriver的情况
|
||||
"""
|
||||
# TODO
|
||||
|
||||
model = PaddleNormalModel_Classification(2, 100)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
|
||||
assert isinstance(driver, PaddleSingleDriver)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
[[1, 2, 3]]
|
||||
"driver",
|
||||
["fleet"]
|
||||
)
|
||||
def test_get_fleet(device):
|
||||
@magic_argv_env_context
|
||||
def test_get_fleet_2(driver, device):
|
||||
"""
|
||||
测试 fleet 多卡的初始化情况
|
||||
"""
|
||||
|
||||
model = PaddleNormalModel_Classification(2, 100)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
model = PaddleNormalModel_Classification_1(64, 10)
|
||||
driver = initialize_paddle_driver(driver, device, model)
|
||||
|
||||
assert isinstance(driver, PaddleFleetDriver)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
[[1,2,3]]
|
||||
[[0, 2, 3], -1]
|
||||
)
|
||||
def test_get_fleet(device):
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
["paddle", "fleet"]
|
||||
)
|
||||
@magic_argv_env_context
|
||||
def test_get_fleet(driver, device):
|
||||
"""
|
||||
测试 launch 启动 fleet 多卡的初始化情况
|
||||
测试 fleet 多卡的初始化情况
|
||||
"""
|
||||
# TODO
|
||||
|
||||
model = PaddleNormalModel_Classification(2, 100)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
model = PaddleNormalModel_Classification_1(64, 10)
|
||||
driver = initialize_paddle_driver(driver, device, model)
|
||||
|
||||
assert isinstance(driver, PaddleFleetDriver)
|
||||
|
||||
def test_device_out_of_range(device):
|
||||
@pytest.mark.parametrize(
|
||||
("driver", "device"),
|
||||
[("fleet", "cpu")]
|
||||
)
|
||||
@magic_argv_env_context
|
||||
def test_get_fleet_cpu(driver, device):
|
||||
"""
|
||||
测试试图在 cpu 上初始化分布式训练的情况
|
||||
"""
|
||||
model = PaddleNormalModel_Classification_1(64, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_paddle_driver(driver, device, model)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
["paddle", "fleet"]
|
||||
)
|
||||
@magic_argv_env_context
|
||||
def test_device_out_of_range(driver, device):
|
||||
"""
|
||||
测试传入的device超过范围的情况
|
||||
"""
|
||||
pass
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_paddle_driver(driver, device, model)
|
Loading…
Reference in New Issue
Block a user