fastNLP/tests/core/metrics/utils.py
2022-05-03 07:57:32 +00:00

45 lines
1.4 KiB
Python

import os, sys
import socket
from typing import Union
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed
def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
"""Setup ddp environment."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def _assert_allclose(my_result: Union[float, np.ndarray], sklearn_result: Union[float, np.ndarray],
atol: float = 1e-8) -> None:
"""
测试对比结果,这里不用非得是必须数组且维度对应,一些其他情况例如 np.allclose(np.array([[1e10, ], ]), 1e10+1) 也是 True
:param my_result: 可以不限设备等
:param sklearn_result:
:param atol:
:return:
"""
assert np.allclose(a=my_result, b=sklearn_result, atol=atol)