ModelLink2/torch_trans.py

2023-06-12 14:42:29 +08:00
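import torch
import torch_npu
# Importing transfer_to_npu remaps torch.cuda calls and "cuda" device
# arguments to their torch_npu equivalents for Ascend NPUs.
from torch_npu.contrib import transfer_to_npu
# Redirect torch.cuda.init to the NPU initializer, then initialize the NPU.
torch.cuda.init = torch.npu.init
torch.npu.init()
# Code that reads torch.cuda.default_generators (e.g. for RNG state
# tracking) now receives the NPU's default generators instead.
torch.cuda.default_generators = torch_npu.npu.default_generators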
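The reassignments above work because Python looks up `torch.cuda.init` and `torch.cuda.default_generators` on the module object each time they are accessed. The sketch below uses plain `types.SimpleNamespace` objects as hypothetical stand-ins for `torch.cuda` and `torch.npu`, so it needs neither torch nor torch_npu. It illustrates the general attribute-patching pattern and one caveat: a name bound with something like `from torch.cuda import init` *before* the patch runs keeps the original function. It is not a description of torch_npu's internals.

```python
import types

# Hypothetical stand-ins for torch.cuda and torch.npu; no real torch calls are made.
cuda = types.SimpleNamespace(init=lambda: "cuda.init")
npu = types.SimpleNamespace(init=lambda: "npu.init")

# A name bound *before* the patch keeps pointing at the original function.
early_init = cuda.init


def framework_setup():
    # The attribute is looked up at call time, so it sees the patched value.
    return cuda.init()


# Same pattern as the shim: rebind the attribute on the shared namespace object.
cuda.init = npu.init

late_result = framework_setup()  # resolved after the patch
early_result = early_init()      # captured before the patch
print(late_result, early_result)  # prints "npu.init cuda.init"
```

This suggests importing the shim before any module that captures these `torch.cuda` attributes by name at import time.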