修复单卡的设备逻辑

This commit is contained in:
x54-729 2022-04-09 16:48:18 +08:00
parent f5f1c280e0
commit 49d18f3683
3 changed files with 10 additions and 5 deletions

View File

@ -241,7 +241,6 @@ class PaddleFleetDriver(PaddleDriver):
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
# 设置参数和初始化分布式环境
reset_seed()
fleet.init(self.role_maker, self.is_collective, self.strategy)
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))

View File

@ -3,6 +3,7 @@ from typing import Optional, Dict, Union
from .paddle_driver import PaddleDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
auto_param_call,
get_paddle_gpu_str,
@ -92,7 +93,12 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward
def setup(self):
    """
    Restrict this process to the single GPU named by ``self.model_device``.

    The device id derived from ``self.model_device`` is written to
    ``CUDA_VISIBLE_DEVICES``. When the variable named by
    ``USER_CUDA_VISIBLE_DEVICES`` holds a non-empty comma-separated list,
    the id is treated as an index into that list and the selected entry is
    written instead. Afterwards paddle and the model are moved to
    ``"gpu:0"``.

    :raises IndexError: if the device id is outside the user-provided
        device list.
    """
    device_id = get_paddle_device_id(self.model_device)
    # ``os.environ[...]`` raises KeyError when the variable is unset, which
    # would make the "not set" fallback below unreachable; ``.get`` returns
    # None in that case.
    user_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES)
    if user_visible_devices:
        # Non-empty: the user set their own visible-device list, so map the
        # index onto the entry they listed at that position.
        device_id = user_visible_devices.split(",")[device_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    # NOTE(review): presumably "gpu:0" is correct because only one device
    # remains visible after the assignment above — confirm against paddle's
    # device-initialisation timing.
    paddle.device.set_device("gpu:0")
    self.model.to("gpu:0")

View File

@ -271,10 +271,10 @@ def get_device_from_visible(device: Union[str, int]):
return idx
else:
# 利用 USER_CUDA_VISIBLE_DEVICES 获取用户期望的设备
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visiblde_devices is not None and user_visiblde_devices != "":
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is not None and user_visible_devices != "":
# 不为空,说明用户设置了 CUDA_VISIBLE_DEVICES
idx = user_visiblde_devices.split(",")[idx]
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)