mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
import os, sys
|
|
import socket
|
|
from typing import Union
|
|
|
|
import torch
|
|
from torch import distributed
|
|
import numpy as np
|
|
|
|
|
|
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
    """Prepare the current process for a localhost DDP run.

    Exports the rendezvous settings through the ``MASTER_ADDR`` (always
    ``"localhost"``) and ``MASTER_PORT`` environment variables and then, when
    possible, initialises the default process group with the ``gloo`` backend.

    The process group is NOT created when ``torch.distributed`` reports itself
    as unavailable or when running on ``win32`` / ``cygwin``; the environment
    variables are still set in that case.

    :param rank: rank forwarded to ``init_process_group``
    :param world_size: world size forwarded to ``init_process_group``
    :param master_port: port written (as a string) to ``MASTER_PORT``
    :return:
    """
    env = os.environ
    env["MASTER_ADDR"] = "localhost"
    env["MASTER_PORT"] = str(master_port)

    # Guard clause: nothing more to do where the process group is not set up.
    if not torch.distributed.is_available() or sys.platform in ("win32", "cygwin"):
        return

    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
|
|
|
|
|
def find_free_network_port() -> int:
    """Find a free TCP port on localhost.

    Useful in single-node training when there is no real master node to
    connect to but the ``MASTER_PORT`` environment variable still has to be
    set.

    The port is obtained by binding a temporary socket to port ``0`` so the
    operating system picks an unused one. The socket is closed before the
    function returns, so the port is only known to have been free at the time
    of the call; nothing reserves it afterwards.

    :return: the port number chosen by the operating system
    """
    # The context manager guarantees the socket is closed even if bind/listen
    # raises, which the previous manual ``close()`` did not.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        sock.listen(1)
        return sock.getsockname()[1]
|
|
|
|
|
|
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
                     atol: float = 1e-8) -> None:
    """
    Check that two results agree within tolerance, using ``np.allclose``.

    The inputs do not have to be arrays of matching dimensions: both values
    are handed straight to ``np.allclose``, so other combinations also pass,
    e.g. ``np.allclose(np.array([[1e10, ], ]), 1e10+1)`` is ``True`` as well
    (NumPy's default relative tolerance still applies on top of ``atol``).

    :param my_result: the value under test. The original note says it is not
        restricted to a particular device; NOTE(review): ``np.allclose`` needs
        NumPy-convertible input -- confirm for non-CPU tensors.
    :param sklearn_result: the reference value to compare against
    :param atol: absolute tolerance forwarded to ``np.allclose``
    :return:
    :raises AssertionError: if the two results are not close
    """
    # Raise explicitly instead of using ``assert`` so the check is not stripped
    # under ``python -O`` and a failure reports the values that differed.
    if not np.allclose(a=my_result, b=sklearn_result, atol=atol):
        raise AssertionError(
            f"results are not close (atol={atol}): got {my_result!r}, expected {sklearn_result!r}"
        )
|