
Website:

Website link

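BigGAN Automatic Image Generation SDK

Generates images for any of the 1000 ImageNet classes.

Download the model and place it in the models directory.

Supported classes include: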

  • tench, Tinca tinca
  • goldfish, Carassius auratus
  • great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
  • tiger shark, Galeocerdo cuvieri
  • hammerhead, hammerhead shark
  • electric ray, crampfish, numbfish, torpedo
  • stingray
  • cock
  • hen
  • ostrich, Struthio camelus
  • brambling, Fringilla montifringilla
  • goldfinch, Carduelis carduelis
  • house finch, linnet, Carpodacus mexicanus
  • junco, snowbird
  • indigo bunting, indigo finch, indigo bird, Passerina cyanea
  • robin, American robin, Turdus migratorius
  • ...

Click to download
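
As a rough illustration of how a class index relates to the labels above, the sketch below looks up the first 16 labels by index. It assumes the excerpt is listed in class-index order starting at 0; `IMAGENET_HEAD` and `class_name` are hypothetical names introduced here and are not part of the SDK.

```python
# First 16 ImageNet labels, copied from the list above (assumed indices 0-15).
IMAGENET_HEAD = [
    "tench", "goldfish", "great white shark", "tiger shark", "hammerhead",
    "electric ray", "stingray", "cock", "hen", "ostrich", "brambling",
    "goldfinch", "house finch", "junco", "indigo bunting", "robin",
]


def class_name(index: int) -> str:
    """Return the primary label for an ImageNet class index (hypothetical helper)."""
    if not 0 <= index <= 999:
        raise ValueError(f"class index must be in 0..999, got {index}")
    if index < len(IMAGENET_HEAD):
        return IMAGENET_HEAD[index]
    # Labels beyond the excerpt are not reproduced in this sketch.
    return f"class {index}"


if __name__ == "__main__":
    print(class_name(11))
```

Under that ordering assumption, index 11 corresponds to the goldfinch entry in the list.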

The SDK takes two parameters:

  • size: the image size; 128, 256, and 512 are supported, e.g. size = 512;
  • imageClass: the ImageNet class index, from 0 to 999, e.g. imageClass = 156;
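
The two parameters can be checked before running generation. The following is a minimal sketch, not the SDK's own code: `check_params` and `LATENT_DIM` are hypothetical names, and the latent lengths are taken from the export script later in this README.

```python
# Latent vector length per image size, as used by the export script in this README.
LATENT_DIM = {128: 120, 256: 140, 512: 128}


def check_params(size: int, image_class: int) -> int:
    """Validate size and imageClass; return the latent length for that size."""
    if size not in LATENT_DIM:
        raise ValueError(f"size must be one of {sorted(LATENT_DIM)}, got {size}")
    if not 0 <= image_class <= 999:
        raise ValueError(f"imageClass must be in 0..999, got {image_class}")
    return LATENT_DIM[size]


if __name__ == "__main__":
    print(check_params(512, 156))
```

Rejecting bad values up front gives a clearer error than a shape mismatch inside the model.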

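Running the example - BigGAN

  • Test image: class 11, size 512x512 (img1)

  • Test image: class 156, size 512x512 (img2)

  • Test image: class 821, size 512x512 (img3)

After a successful run, the command line should show the following: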

...
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Generated image has been saved in: build/output/

Open-source algorithm

1. Open-source algorithm used by the SDK

2. How is the model exported?

from src.biggan import BigGAN128
from src.biggan import BigGAN256
from src.biggan import BigGAN512

import torch
import torchvision

from scipy.stats import truncnorm

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--truncation', type=float, default=0.4)
    parser.add_argument('-s', '--size', type=int, choices=[128, 256, 512], default=512)
    parser.add_argument('-c', '--class_label', type=int, choices=range(0, 1000), default=156)
    parser.add_argument('-w', '--pretrained_weight', type=str, required=True)
    args = parser.parse_args()

    # Keep the truncation value strictly inside (0.02, 1.0).
    truncation = torch.clamp(torch.tensor(args.truncation), min=0.02+1e-4, max=1.0-1e-4).float()
    # Class label as a one-element int64 tensor.
    c = torch.tensor((args.class_label,)).long()

    # Sample a truncated-normal latent vector; its length depends on the model size.
    if args.size == 128:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 120))).float()
        biggan = BigGAN128()
    elif args.size == 256:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 140))).float()
        biggan = BigGAN256()
    elif args.size == 512:
        z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float()
        biggan = BigGAN512()

    biggan.load_state_dict(torch.load(args.pretrained_weight))
    biggan.eval()

    # Generate the model for DJL.
    # Example inputs for tracing: latent vector, class label, truncation.
    example_inputs = (z, c, torch.tensor(0.2))
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(biggan, example_inputs)
    # Save the TorchScript model.
    traced_script_module.save("traced_model.pt")

    # Generate one image without tracking gradients.
    with torch.no_grad():
        img = biggan(z, c, truncation)

    # Map the generator output from [-1, 1] to [0, 1], then display it.
    img = 0.5 * (img.data + 1)
    pil = torchvision.transforms.ToPILImage()(img.squeeze())
    pil.show()
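
The latent sampling and output rescaling in the script above can be illustrated without PyTorch or SciPy. The sketch below is a standard-library stand-in, not the script's actual code path: it replaces `truncnorm.rvs(-2.0, 2.0, ...)` with simple rejection sampling, and `sample_truncated_normal`, `make_latent`, and `to_unit_range` are hypothetical names.

```python
import random

# Latent vector length per image size, copied from the export script above.
LATENT_DIM = {128: 120, 256: 140, 512: 128}


def sample_truncated_normal(n, low=-2.0, high=2.0, rng=None):
    """Draw n standard-normal samples restricted to [low, high] by rejection."""
    if rng is None:
        rng = random.Random()
    out = []
    while len(out) < n:
        x = rng.gauss(0.0, 1.0)
        if low <= x <= high:  # discard draws outside the truncation bounds
            out.append(x)
    return out


def make_latent(size, truncation=0.4, rng=None):
    """Build a latent vector scaled by the clamped truncation value."""
    # Same clamp bounds as the export script.
    t = max(0.02 + 1e-4, min(1.0 - 1e-4, truncation))
    return [t * v for v in sample_truncated_normal(LATENT_DIM[size], rng=rng)]


def to_unit_range(value):
    """Map a generator output value from [-1, 1] to [0, 1]."""
    return 0.5 * (value + 1.0)


if __name__ == "__main__":
    z = make_latent(512, truncation=0.4, rng=random.Random(0))
    print(len(z), max(abs(v) for v in z) <= 0.8)
```

With a truncation of 0.4, every latent component has magnitude at most 0.8, because the samples are bounded by 2 before scaling.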

Other help information

https://aias.top/guides.html

Git repositories

GitHub link
Gitee link

Help documentation: