mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
完善 mix_modules/utils.py 的文档
This commit is contained in:
parent
a49c9d7bdf
commit
7704ccc89d
7
docs/source/fastNLP.core.callbacks.fitlog_callback.rst
Normal file
7
docs/source/fastNLP.core.callbacks.fitlog_callback.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.fitlog\_callback module
|
||||
==============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.fitlog_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -25,6 +25,7 @@ Submodules
|
||||
fastNLP.core.callbacks.callback_manager
|
||||
fastNLP.core.callbacks.checkpoint_callback
|
||||
fastNLP.core.callbacks.early_stop_callback
|
||||
fastNLP.core.callbacks.fitlog_callback
|
||||
fastNLP.core.callbacks.has_monitor_callback
|
||||
fastNLP.core.callbacks.load_best_model_callback
|
||||
fastNLP.core.callbacks.lr_scheduler_callback
|
||||
|
15
docs/source/fastNLP.modules.mix_modules.rst
Normal file
15
docs/source/fastNLP.modules.mix_modules.rst
Normal file
@ -0,0 +1,15 @@
|
||||
fastNLP.modules.mix\_modules package
|
||||
====================================
|
||||
|
||||
.. automodule:: fastNLP.modules.mix_modules
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.modules.mix_modules.utils
|
7
docs/source/fastNLP.modules.mix_modules.utils.rst
Normal file
7
docs/source/fastNLP.modules.mix_modules.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.modules.mix\_modules.utils module
|
||||
=========================================
|
||||
|
||||
.. automodule:: fastNLP.modules.mix_modules.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
15
docs/source/fastNLP.modules.rst
Normal file
15
docs/source/fastNLP.modules.rst
Normal file
@ -0,0 +1,15 @@
|
||||
fastNLP.modules package
|
||||
=======================
|
||||
|
||||
.. automodule:: fastNLP.modules
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.modules.mix_modules
|
@ -15,3 +15,4 @@ Subpackages
|
||||
fastNLP.core
|
||||
fastNLP.envs
|
||||
fastNLP.io
|
||||
fastNLP.modules
|
||||
|
@ -22,9 +22,9 @@ from .utils import apply_to_collection
|
||||
|
||||
def _convert_data_device(device: Union[str, int]) -> str:
|
||||
"""
|
||||
用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 ``fastNLP`` 会将
|
||||
用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 **fastNLP** 会将
|
||||
可见的设备保存在 ``USER_CUDA_VISIBLE_DEVICES`` 中,并且将 ``CUDA_VISIBLE_DEVICES`` 设置为可见的第一张显卡;这是为
|
||||
了顺利执行 ``paddle`` 的分布式训练而设置的。
|
||||
了顺利执行 **paddle** 的分布式训练而设置的。
|
||||
|
||||
在这种情况下,单纯使用 ``driver.data_device`` 是无效的。比如在分布式训练中将设备设置为 ``[0,2,3]`` ,且用户设置了
|
||||
``CUDA_VISIBLE_DEVICES=3,4,5,6`` ,那么在 ``rank1``的进程中有::
|
||||
@ -127,7 +127,7 @@ def get_paddle_device_id(device: Union[str, int]) -> int:
|
||||
|
||||
def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> Any:
|
||||
r"""
|
||||
将 ``paddle`` 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。
|
||||
将 **paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。
|
||||
|
||||
:param batch: 需要进行迁移的数据集合;
|
||||
:param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串;当这个参数
|
||||
@ -145,20 +145,20 @@ def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) ->
|
||||
|
||||
def is_in_paddle_dist() -> bool:
|
||||
"""
|
||||
判断是否处于 ``paddle`` 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。
|
||||
判断是否处于 **paddle** 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。
|
||||
"""
|
||||
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)
|
||||
|
||||
|
||||
def is_in_fnlp_paddle_dist() -> bool:
|
||||
"""
|
||||
判断是否处于 ``fastNLP`` 拉起的 ``paddle`` 分布式进程中
|
||||
判断是否处于 **fastNLP** 拉起的 **paddle** 分布式进程中
|
||||
"""
|
||||
return FASTNLP_DISTRIBUTED_CHECK in os.environ
|
||||
|
||||
|
||||
def is_in_paddle_launch_dist() -> bool:
|
||||
"""
|
||||
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 ``paddle`` 分布式进程中
|
||||
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中
|
||||
"""
|
||||
return FASTNLP_BACKEND_LAUNCH in os.environ
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
该文件用于为 ``fastNLP`` 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中
|
||||
该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中
|
||||
的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突
|
||||
"""
|
||||
import sys
|
||||
|
@ -44,11 +44,11 @@ class TorchTransferableDataType(ABC):
|
||||
def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
|
||||
non_blocking: Optional[bool] = True) -> Any:
|
||||
r"""
|
||||
在 ``pytorch`` 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变;
|
||||
在 **pytorch** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变;
|
||||
|
||||
:param batch: 需要迁移的数据;
|
||||
:param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作;
|
||||
:param non_blocking: ``pytorch`` 的数据迁移方法 ``to`` 的参数;
|
||||
:param non_blocking: **pytorch** 的数据迁移方法 ``to`` 的参数;
|
||||
:return: 迁移到新设备上的数据集合;
|
||||
"""
|
||||
if device is None:
|
||||
|
@ -55,7 +55,7 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
|
||||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
|
||||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
|
||||
r"""
|
||||
该函数会根据输入函数的形参名从 ``*args`` (均为 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
|
||||
该函数会根据输入函数的形参名从 ``*args`` (均为 **dict** 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
|
||||
``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为
|
||||
``value`` 的参数。
|
||||
|
||||
@ -259,21 +259,21 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
|
||||
|
||||
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
|
||||
r"""
|
||||
用来实现将输入的 ``batch`` 或者输出的 ``outputs`` 通过 ``mapping`` 将键值进行更换的功能;
|
||||
用来实现将输入的 **batch** 或者输出的 **outputs** 通过 ``mapping`` 将键值进行更换的功能;
|
||||
该函数应用于 ``input_mapping`` 和 ``output_mapping``;
|
||||
|
||||
* 对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用;
|
||||
* 对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step`
|
||||
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用;
|
||||
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用;
|
||||
|
||||
转换的逻辑按优先级依次为:
|
||||
|
||||
1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``;
|
||||
2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``;
|
||||
1. 如果 ``mapping`` 是一个函数,那么会直接返回 **mapping(data)**;
|
||||
2. 如果 ``mapping`` 是一个 **Dict**,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``;
|
||||
|
||||
* 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``;
|
||||
* 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换;
|
||||
* 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典::
|
||||
* 如果 ``data`` 是 **Dict**,那么该函数会将 ``data`` 的 ``key`` 替换为 **mapping[key]**;
|
||||
* 如果 ``data`` 是 **dataclass**,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 **Dict**,然后进行转换;
|
||||
* 如果 ``data`` 是 **Sequence**,那么该函数会先将其转换成一个对应的字典::
|
||||
|
||||
{
|
||||
"_0": list[0],
|
||||
@ -281,7 +281,7 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
|
||||
...
|
||||
}
|
||||
|
||||
然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。
|
||||
然后使用 ``mapping`` 对这个字典进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``'_number'`` 这个形式。
|
||||
|
||||
:param mapping: 用于转换的字典或者函数;当 ``mapping`` 是函数时,返回值必须为字典类型;
|
||||
:param data: 需要被转换的对象;
|
||||
@ -459,7 +459,7 @@ def _is_iterable(value):
|
||||
|
||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
|
||||
r"""
|
||||
用于在 ``fastNLP`` 中展示数据的函数::
|
||||
用于在 **fastNLP** 中展示数据的函数::
|
||||
|
||||
>>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
|
||||
+-----------+-----------+-----------------+
|
||||
|
@ -0,0 +1,242 @@
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.utils import paddle_to, apply_to_collection
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"paddle2torch",
|
||||
"torch2paddle",
|
||||
"jittor2torch",
|
||||
"torch2jittor",
|
||||
]
|
||||
|
||||
def _paddle2torch(paddle_tensor: 'paddle.Tensor', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
|
||||
"""
|
||||
将 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` ,并且能够保留梯度进行反向传播
|
||||
|
||||
:param paddle_tensor: 要转换的 **paddle** 张量;
|
||||
:param device: 是否将转换后的张量迁移到特定设备上,为 ``None``时,和输入的张量相同;
|
||||
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
|
||||
:return: 转换后的 **torch** 张量;
|
||||
"""
|
||||
no_gradient = paddle_tensor.stop_gradient if no_gradient is None else no_gradient
|
||||
paddle_numpy = paddle_tensor.numpy()
|
||||
if not np.issubdtype(paddle_numpy.dtype, np.inexact):
|
||||
no_gradient = True
|
||||
|
||||
if device is None:
|
||||
if paddle_tensor.place.is_gpu_place():
|
||||
# paddlepaddle有两种Place,对应不同的device id获取方式
|
||||
if hasattr(paddle_tensor.place, "gpu_device_id"):
|
||||
# paddle.fluid.core_avx.Place
|
||||
# 在gpu环境下创建张量的话,张量的place是这一类型
|
||||
device = f"cuda:{paddle_tensor.place.gpu_device_id()}"
|
||||
else:
|
||||
# paddle.CUDAPlace
|
||||
device = f"cuda:{paddle_tensor.place.get_device_id()}"
|
||||
else:
|
||||
# TODO: 可能需要支持xpu等设备
|
||||
device = "cpu"
|
||||
|
||||
if not no_gradient:
|
||||
# 保持梯度,并保持反向传播
|
||||
# torch.tensor会保留numpy数组的类型
|
||||
torch_tensor = torch.tensor(paddle_numpy, requires_grad=True, device=device)
|
||||
hook = torch_tensor.register_hook(
|
||||
lambda grad: paddle.autograd.backward(paddle_tensor, paddle.to_tensor(grad.cpu().numpy()))
|
||||
)
|
||||
else:
|
||||
# 不保留梯度
|
||||
torch_tensor = torch.tensor(paddle_numpy, requires_grad=False, device=device)
|
||||
|
||||
return torch_tensor
|
||||
|
||||
|
||||
def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient: bool = None) -> 'paddle.Tensor':
|
||||
"""
|
||||
将 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor`,并且能够保留梯度进行反向传播。
|
||||
|
||||
:param torch_tensor: 要转换的 **torch** 张量;
|
||||
:param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,和输入的张量相同;
|
||||
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
|
||||
:return: 转换后的 **paddle** 张量;
|
||||
"""
|
||||
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient
|
||||
if device is None:
|
||||
if torch_tensor.is_cuda:
|
||||
device = f"gpu:{torch_tensor.device.index}"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
if not no_gradient:
|
||||
# 保持梯度并保持反向传播
|
||||
# paddle的stop_gradient和torch的requires_grad表现是相反的
|
||||
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
|
||||
hook = paddle_tensor.register_hook(
|
||||
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
|
||||
)
|
||||
else:
|
||||
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
|
||||
|
||||
paddle_tensor = paddle_to(paddle_tensor, device)
|
||||
|
||||
return paddle_tensor
|
||||
|
||||
|
||||
def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
|
||||
"""
|
||||
将 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。
|
||||
|
||||
:param jittor_var: 要转换的 **jittor** 变量;
|
||||
:param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,根据 ``jittor.flags.use_cuda`` 决定;
|
||||
:param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
|
||||
:return: 转换后的 **torch** 张量;
|
||||
"""
|
||||
# TODO: warning:无法保留梯度
|
||||
# jittor的grad可以通过callback进行传递
|
||||
# 如果outputs有_grad键,可以实现求导
|
||||
no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
|
||||
if no_gradient == False:
|
||||
warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
|
||||
jittor_numpy = jittor_var.numpy()
|
||||
if not np.issubdtype(jittor_numpy.dtype, np.inexact):
|
||||
no_gradient = True
|
||||
|
||||
if device is None:
|
||||
# jittor的设备分配是自动的
|
||||
# 根据use_cuda判断
|
||||
if jittor.flags.use_cuda:
|
||||
device = "cuda:0"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
torch_tensor = torch.tensor(jittor_numpy, requires_grad=not no_gradient, device=device)
|
||||
|
||||
return torch_tensor
|
||||
|
||||
|
||||
def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'jittor.Var':
|
||||
"""
|
||||
将 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。
|
||||
|
||||
:param torch_tensor: 要转换的 **torch** 张量;
|
||||
:param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
|
||||
:return: 转换后的 **jittor** 变量;
|
||||
"""
|
||||
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient
|
||||
|
||||
if not no_gradient:
|
||||
# 保持梯度并保持反向传播
|
||||
jittor_var = jittor.Var(torch_tensor.detach().numpy())
|
||||
jittor_var.requires_grad = True
|
||||
hook = jittor_var.register_hook(
|
||||
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
|
||||
)
|
||||
else:
|
||||
jittor_var = jittor.Var(torch_tensor.detach().numpy())
|
||||
jittor_var.requires_grad = False
|
||||
|
||||
return jittor_var
|
||||
|
||||
|
||||
def torch2paddle(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor` 。
|
||||
|
||||
:param batch: 包含 :class:`torch.Tensor` 类型的数据集合
|
||||
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None`` 时,和输入保持一致;
|
||||
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
|
||||
:return: 转换后的数据;
|
||||
"""
|
||||
|
||||
return apply_to_collection(
|
||||
batch,
|
||||
dtype=torch.Tensor,
|
||||
function=_torch2paddle,
|
||||
device=device,
|
||||
no_gradient=no_gradient,
|
||||
)
|
||||
|
||||
|
||||
def paddle2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` 。
|
||||
|
||||
:param batch: 包含 :class:`paddle.Tensor` 类型的数据集合;
|
||||
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致;
|
||||
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
|
||||
:return: 转换后的数据;
|
||||
"""
|
||||
|
||||
return apply_to_collection(
|
||||
batch,
|
||||
dtype=paddle.Tensor,
|
||||
function=_paddle2torch,
|
||||
device=device,
|
||||
no_gradient=no_gradient,
|
||||
)
|
||||
|
||||
|
||||
def jittor2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。
|
||||
|
||||
.. note::
|
||||
|
||||
注意,由于 **pytorch** 和 **jittor** 之间的差异,从 :class:`jittor.Var` 转换
|
||||
至 :class:`torch.Tensor` 的过程中无法保留原张量的梯度。
|
||||
|
||||
:param batch: 包含 :class:`jittor.Var` 类型的数据集合;
|
||||
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致;
|
||||
:param no_gradient: 是否保留原张量的梯度,在这个函数中该参数无效。
|
||||
:return: 转换后的数据;
|
||||
"""
|
||||
|
||||
return apply_to_collection(
|
||||
batch,
|
||||
dtype=jittor.Var,
|
||||
function=_jittor2torch,
|
||||
device=device,
|
||||
no_gradient=no_gradient,
|
||||
)
|
||||
|
||||
|
||||
def torch2jittor(batch: Any, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。
|
||||
|
||||
.. note::
|
||||
|
||||
**jittor** 会自动为创建的变量分配设备。
|
||||
|
||||
:param batch: 包含 :class:`torch.Tensor` 类型的数据集合;
|
||||
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
|
||||
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
|
||||
:return: 转换后的数据;
|
||||
"""
|
||||
|
||||
return apply_to_collection(
|
||||
batch,
|
||||
dtype=torch.Tensor,
|
||||
function=_torch2jittor,
|
||||
no_gradient=no_gradient,
|
||||
)
|
Loading…
Reference in New Issue
Block a user