mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
增加paddle单卡的accuracy测试用例
This commit is contained in:
parent
3ea74b52d2
commit
665d79a3ed
@ -14,11 +14,13 @@ if _NEED_IMPORT_PADDLE:
|
||||
import paddle.distributed as dist
|
||||
from paddle.fluid.dygraph import parallel_helper
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
|
||||
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_result, result, group)
|
||||
return gathered_result
|
||||
|
||||
|
||||
class PaddleBackend(Backend):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -124,4 +126,3 @@ class PaddleBackend(Backend):
|
||||
# TODO 如果在这里处理的话,会不会在别的地方引起bug?
|
||||
device = get_device_from_visible(device)
|
||||
return paddle_to(tensor, device)
|
||||
|
||||
|
@ -11,7 +11,6 @@ from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gathe
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
|
||||
@ -33,7 +32,7 @@ class TorchBackend(Backend):
|
||||
if dist.is_initialized():
|
||||
if method is None:
|
||||
raise AggregateMethodError(should_have_aggregate_method=True)
|
||||
tensor = fastnlp_torch_all_gather(tensor)
|
||||
tensor = self.all_gather_object(tensor)
|
||||
if isinstance(tensor[0], torch.Tensor):
|
||||
tensor = torch.stack(tensor)
|
||||
# 第一步, aggregate结果
|
||||
|
59
tests/core/metrics/test_accutacy_paddle.py
Normal file
59
tests/core/metrics/test_accutacy_paddle.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import paddle
|
||||
import paddle.distributed
|
||||
import paddle.distributed.fleet.base.role_maker as role_maker
|
||||
import paddle.distributed.fleet as fleet
|
||||
from fastNLP.core.metrics import Accuracy
|
||||
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试 单机单卡情况下的Accuracy
|
||||
#
|
||||
############################################################################
|
||||
def test_accuracy_single():
|
||||
pred = paddle.to_tensor([[1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669,
|
||||
0.46514302],
|
||||
[-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844,
|
||||
1.02202022],
|
||||
[-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560,
|
||||
-0.02559055],
|
||||
[-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579,
|
||||
-0.83857095],
|
||||
[0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588,
|
||||
0.08403367],
|
||||
[0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942,
|
||||
0.48605862],
|
||||
[1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864,
|
||||
0.64618838],
|
||||
[0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472,
|
||||
-0.45802942],
|
||||
[0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982,
|
||||
-0.13550705],
|
||||
[-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012,
|
||||
2.58735061]])
|
||||
tg = paddle.to_tensor([1, 2, 1, 3, 5, 4, 4, 2, 1, 5])
|
||||
acc_metric = Accuracy()
|
||||
acc_metric.update(pred, tg)
|
||||
result = acc_metric.get_metric()
|
||||
true_result = {'acc': 0.3}
|
||||
assert true_result == result
|
||||
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试 单机多卡情况下的Accuracy
|
||||
#
|
||||
############################################################################
|
||||
def test_accuracy_ddp():
|
||||
launcher = FleetLauncher(devices=[0, 1])
|
||||
launcher.launch()
|
||||
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
||||
fleet.init(role)
|
||||
if fleet.is_server():
|
||||
pass
|
||||
elif fleet.is_worker():
|
||||
print(os.getenv("PADDLE_TRAINER_ID"))
|
||||
|
Loading…
Reference in New Issue
Block a user