ModelLink2/torch_trans.py
2023-06-12 14:42:29 +08:00

8 lines
192 B
Python

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Redirect CUDA-facing torch entry points to the Ascend NPU backend by
# monkey-patching attributes on torch.cuda. These are module-level statements,
# so they take effect as soon as this module is imported.
# NOTE(review): `transfer_to_npu` (imported above) is never referenced by name
# here; presumably it is imported only for its own patching side effects —
# confirm before removing that import.

# After this rebinding, any later call to torch.cuda.init() runs torch.npu.init().
torch.cuda.init = torch.npu.init
# Call the NPU init eagerly, once, at import time.
torch.npu.init()
# Publish the NPU default generators under the CUDA attribute name. Presumably
# this lets code that reads torch.cuda.default_generators (e.g. RNG state
# save/restore) get NPU generators instead — TODO confirm against callers.
# This line runs after the init() call above. NOTE(review): it is unclear
# whether the generator collection depends on initialization having run, so
# confirm before reordering.
torch.cuda.default_generators = torch_npu.npu.default_generators