mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
修复但å单卡的设备逻辑
This commit is contained in:
parent
f5f1c280e0
commit
49d18f3683
@ -241,7 +241,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
|
||||
launcher.launch()
|
||||
# 设置参数和初始化分布式环境
|
||||
reset_seed()
|
||||
fleet.init(self.role_maker, self.is_collective, self.strategy)
|
||||
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
|
||||
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
|
||||
|
@ -3,6 +3,7 @@ from typing import Optional, Dict, Union
|
||||
|
||||
from .paddle_driver import PaddleDriver
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
|
||||
from fastNLP.core.utils import (
|
||||
auto_param_call,
|
||||
get_paddle_gpu_str,
|
||||
@ -92,7 +93,12 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
self._test_signature_fn = model.forward
|
||||
|
||||
def setup(self):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
|
||||
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
|
||||
device_id = get_paddle_device_id(self.model_device)
|
||||
if user_visible_devices is not None and user_visible_devices != "":
|
||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
|
||||
device_id = user_visible_devices.split(",")[device_id]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
|
||||
paddle.device.set_device("gpu:0")
|
||||
self.model.to("gpu:0")
|
||||
|
||||
|
@ -271,10 +271,10 @@ def get_device_from_visible(device: Union[str, int]):
|
||||
return idx
|
||||
else:
|
||||
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
|
||||
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
|
||||
if user_visiblde_devices is not None and user_visiblde_devices != "":
|
||||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
|
||||
if user_visible_devices is not None and user_visible_devices != "":
|
||||
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
|
||||
idx = user_visiblde_devices.split(",")[idx]
|
||||
idx = user_visible_devices.split(",")[idx]
|
||||
else:
|
||||
idx = str(idx)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user