remove biggan_sdk and style_transfer_sdk

This commit is contained in:
Calvin 2023-03-10 16:19:33 +08:00
parent 4a7b50db35
commit 0a47a96f75
24 changed files with 0 additions and 1156 deletions

View File

@ -1,137 +0,0 @@
### 官网:
[官网链接](https://www.aias.top/)
### BIGGAN 图像自动生成SDK
能够自动生成1000种类别支持imagenet数据集分类的图片。
### 支持分类如下:
- tench, Tinca tinca
- goldfish, Carassius auratus
- great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
- tiger shark, Galeocerdo cuvieri
- hammerhead, hammerhead shark
- electric ray, crampfish, numbfish, torpedo
- stingray
- cock
- hen
- ostrich, Struthio camelus
- brambling, Fringilla montifringilla
- goldfinch, Carduelis carduelis
- house finch, linnet, Carpodacus mexicanus
- junco, snowbird
- indigo bunting, indigo finch, indigo bird, Passerina cyanea
- robin, American robin, Turdus migratorius
- ...
[点击下载](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/classification_imagenet_sdk/synset.txt)
### SDK包含两个分类器
size 支持 128, 256, 512三种图片尺寸
size = 512;
imageClass 支持imagenet类别0~999
imageClass = 156;
### 运行例子 - BigGAN
- 测试图片类别11图片尺寸512X512
![img1](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image11.png)
- 测试图片类别156图片尺寸512X512
![img2](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image156.png)
- 测试图片类别821图片尺寸512X512
![img3](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image821.png)
运行成功后,命令行应该看到下面的信息:
```text
...
[INFO ] - Number of inter-op threads is 4
[INFO ] - Number of intra-op threads is 4
[INFO ] - Generated image has been saved in: build/output/
```
### 开源算法
#### 1. sdk使用的开源算法
- [BigGAN-Generator-Pretrained-Pytorch](https://github.com/ivclab/BigGAN-Generator-Pretrained-Pytorch)
- [预训练模型 biggan-128](https://tfhub.dev/deepmind/biggan-128/2)
- [预训练模型 biggan-256](https://tfhub.dev/deepmind/biggan-256/2)
- [预训练模型 biggan-512](https://tfhub.dev/deepmind/biggan-512/2)
#### 2. 模型如何导出 ?
- [how_to_convert_your_model_to_torchscript](http://docs.djl.ai/docs/pytorch/how_to_convert_your_model_to_torchscript.html)
- 导出模型
```text
from src.biggan import BigGAN128
from src.biggan import BigGAN256
from src.biggan import BigGAN512
import torch
import torchvision
from scipy.stats import truncnorm
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--truncation', type=float, default=0.4)
parser.add_argument('-s', '--size', type=int, choices=[128, 256, 512], default=512)
parser.add_argument('-c', '--class_label', type=int, choices=range(0, 1000), default=156)
parser.add_argument('-w', '--pretrained_weight', type=str, required=True)
args = parser.parse_args()
truncation = torch.clamp(torch.tensor(args.truncation), min=0.02+1e-4, max=1.0-1e-4).float()
c = torch.tensor((args.class_label,)).long()
if args.size == 128:
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 120))).float()
biggan = BigGAN128()
elif args.size == 256:
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 140))).float()
biggan = BigGAN256()
elif args.size == 512:
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float()
biggan = BigGAN512()
biggan.load_state_dict(torch.load(args.pretrained_weight))
biggan.eval()
#Generate model for DJL
listSample = [z, c, torch.tensor(0.2)]
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(biggan, listSample)
# sm = torch.jit.script(tra)
# Save the TorchScript model
traced_script_module.save("traced_model.pt")
with torch.no_grad():
img = biggan(z, c, truncation)
img = 0.5 * (img.data + 1)
pil = torchvision.transforms.ToPILImage()(img.squeeze())
pil.show()
```
### 其它帮助信息
https://aias.top/guides.html
### Git地址
[Github链接](https://github.com/mymagicpower/AIAS)
[Gitee链接](https://gitee.com/mymagicpower/AIAS)
#### 帮助文档:
- https://aias.top/guides.html
- 1.性能优化常见问题:
- https://aias.top/AIAS/guides/performance.html
- 2.引擎配置包括CPUGPU在线自动加载及本地配置:
- https://aias.top/AIAS/guides/engine_config.html
- 3.模型加载方式(在线自动加载,及本地配置):
- https://aias.top/AIAS/guides/load_model.html
- 4.Windows环境常见问题:
- https://aias.top/AIAS/guides/windows.html

View File

@ -1,42 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
<component name="CheckStyle-IDEA-Module">
<option name="configuration">
<map />
</option>
</component>
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
<output url="file://$MODULE_DIR$/target/classes" />
<output-test url="file://$MODULE_DIR$/target/test-classes" />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/src/test/resources" type="java-test-resource" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="aias-classification-imagenet-lib-0.1.0" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.17.2" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.25" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.17.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.17.2" level="project" />
<orderEntry type="library" name="Maven: ai.djl:api:0.17.0" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.10.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
<orderEntry type="library" name="Maven: ai.djl:basicdataset:0.17.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-csv:1.9.0" level="project" />
<orderEntry type="library" name="Maven: ai.djl:model-zoo:0.17.0" level="project" />
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-engine:0.17.0" level="project" />
<orderEntry type="library" scope="PROVIDED" name="Maven: org.projectlombok:lombok:1.18.18" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.testng:testng:6.8.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.10" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.beanshell:bsh:2.0b4" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: com.beust:jcommander:1.27" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.yaml:snakeyaml:1.6" level="project" />
</component>
</module>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 585 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 653 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 543 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 83 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 99 KiB

View File

@ -1,103 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one
~ or more contributor license agreements. See the NOTICE file
~ distributed with this work for additional information
~ regarding copyright ownership. The ASF licenses this file
~ to you under the Apache License, Version 2.0 (the
~ "License"); you may not use this file except in compliance
~ with the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing,
~ software distributed under the License is distributed on an
~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
~ KIND, either express or implied. See the License for the
~ specific language governing permissions and limitations
~ under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>aias</groupId>
<artifactId>biggan-sdk</artifactId>
<version>0.17.0</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<djl.version>0.17.0</djl.version>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
<version>3.8.1</version>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.17.2</version>
</dependency>
<!-- 服务器端推理引擎 -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.18</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>6.8.1</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -1,72 +0,0 @@
package me.aias;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import me.aias.utils.BigGANTranslator;
import me.aias.utils.ImageUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
* An example of generation using BigGAN.
*/
public final class BigGAN {
private static final Logger logger = LoggerFactory.getLogger(BigGAN.class);
public BigGAN() {
}
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// size 支持 128, 256, 512
int size = 512;
// imageClass 支持imagenet类别1~1000
long imageClass = 156;
Criteria<Long, Image> criteria = new BigGAN().generate(size, 0.4f);
Image image = null;
try (ZooModel<Long, Image> model = ModelZoo.loadModel(criteria);
Predictor<Long, Image> generator = model.newPredictor()) {
image = generator.predict(imageClass);
}
ImageUtils.saveImage(image, "image" + imageClass + ".png", "build/output/");
logger.info("Generated image has been saved in: {}", "build/output/");
}
public Criteria<Long, Image> generate(int size, float truncation) {
String url = null;
if (size == 128) {
size = 120;
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan128.zip";
} else if (size == 256) {
size = 140;
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan256.zip";
} else if (size == 512) {
size = 128;
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan512.zip";
}
BigGANTranslator translator = new BigGANTranslator(size, truncation);
Criteria<Long, Image> criteria =
Criteria.builder()
.optEngine("PyTorch") // Use PyTorch engine
.setTypes(Long.class, Image.class)
.optModelUrls(url)
// .optModelUrls("/Users/calvin/BigGAN-Generator-Pretrained-Pytorch/")
// .optModelName("traced_biggan512_model")
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();
return criteria;
}
}

View File

@ -1,47 +0,0 @@
package me.aias.utils;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
public final class BigGANTranslator implements Translator<Long, Image> {
private float truncation;
private int size;
public BigGANTranslator(int size, float truncation) {
this.size = size;
this.truncation = truncation;
}
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false);
Image image = ImageFactory.getInstance().fromNDArray(output.get(0));
return image;
}
@Override
public NDList processInput(TranslatorContext ctx, Long imageClass) throws Exception {
NDManager manager = ctx.getNDManager();
NDArray seed =
manager.truncatedNormal(new Shape(1, this.size)).clip(-2.0, 2.0).muli(truncation);
return new NDList(seed, manager.create(imageClass).expandDims(0), manager.create(truncation));
}
@Override
public Batchifier getBatchifier() {
return null;
}
}

View File

@ -1,168 +0,0 @@
package me.aias.utils;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.util.RandomUtils;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
public class ImageUtils {
public static Image bufferedImage2DJLImage(BufferedImage img) {
return ImageFactory.getInstance().fromImage(img);
}
public static void saveImage(BufferedImage img, String name, String path) {
Image newImage = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式自动适配
Path outputDir = Paths.get(path);
Path imagePath = outputDir.resolve(name);
// OpenJDK 不能保存 jpg 图片的 alpha channel
try {
newImage.save(Files.newOutputStream(imagePath), "png");
} catch (IOException e) {
e.printStackTrace();
}
}
public static void saveImage(Image img, String name, String path) {
Path outputDir = Paths.get(path);
Path imagePath = outputDir.resolve(name);
// OpenJDK 不能保存 jpg 图片的 alpha channel
try {
img.save(Files.newOutputStream(imagePath), "png");
} catch (IOException e) {
e.printStackTrace();
}
}
public static void saveBoundingBoxImage(
Image img, DetectedObjects detection, String name, String path) throws IOException {
// Make image copy with alpha channel because original image was jpg
img.drawBoundingBoxes(detection);
Path outputDir = Paths.get(path);
Files.createDirectories(outputDir);
Path imagePath = outputDir.resolve(name);
// OpenJDK can't save jpg with alpha channel
img.save(Files.newOutputStream(imagePath), "png");
}
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
// 将绘制图像转换为Graphics2D
Graphics2D g = (Graphics2D) image.getGraphics();
try {
g.setColor(new Color(246, 96, 0));
// 声明画笔属性 单位像素末端无修饰 折线处呈尖角
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
g.setStroke(bStroke);
g.drawRect(x, y, width, height);
} finally {
g.dispose();
}
}
public static void drawImageRect(
BufferedImage image, int x, int y, int width, int height, Color c) {
// 将绘制图像转换为Graphics2D
Graphics2D g = (Graphics2D) image.getGraphics();
try {
g.setColor(c);
// 声明画笔属性 单位像素末端无修饰 折线处呈尖角
BasicStroke bStroke = new BasicStroke(1, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
g.setStroke(bStroke);
g.drawRect(x, y, width, height);
} finally {
g.dispose();
}
}
public static void drawImageText(BufferedImage image, String text) {
Graphics graphics = image.getGraphics();
int fontSize = 100;
Font font = new Font("楷体", Font.PLAIN, fontSize);
try {
graphics.setFont(font);
graphics.setColor(new Color(246, 96, 0));
int strWidth = graphics.getFontMetrics().stringWidth(text);
graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30);
} finally {
graphics.dispose();
}
}
/** 返回外扩人脸 factor = 1, 100%, factor = 0.2, 20% */
public static Image getSubImage(Image img, BoundingBox box, float factor) {
Rectangle rect = box.getBounds();
// 左上角坐标
int x1 = (int) (rect.getX() * img.getWidth());
int y1 = (int) (rect.getY() * img.getHeight());
// 宽度高度
int w = (int) (rect.getWidth() * img.getWidth());
int h = (int) (rect.getHeight() * img.getHeight());
// 左上角坐标
int x2 = x1 + w;
int y2 = y1 + h;
// 外扩大100%防止对齐后人脸出现黑边
int new_x1 = Math.max((int) (x1 + x1 * factor / 2 - x2 * factor / 2), 0);
int new_x2 = Math.min((int) (x2 + x2 * factor / 2 - x1 * factor / 2), img.getWidth() - 1);
int new_y1 = Math.max((int) (y1 + y1 * factor / 2 - y2 * factor / 2), 0);
int new_y2 = Math.min((int) (y2 + y2 * factor / 2 - y1 * factor / 2), img.getHeight() - 1);
int new_w = new_x2 - new_x1;
int new_h = new_y2 - new_y1;
return img.getSubImage(new_x1, new_y1, new_w, new_h);
}
public static void drawJoints(Image img, Image subImg, int x, int y, Joints joints) {
BufferedImage image = (BufferedImage) img.getWrappedImage();
Graphics2D g = (Graphics2D) image.getGraphics();
int stroke = 2;
g.setStroke(new BasicStroke((float) stroke));
int imageWidth = subImg.getWidth();
int imageHeight = subImg.getHeight();
Iterator iterator = joints.getJoints().iterator();
while (iterator.hasNext()) {
Joints.Joint joint = (Joints.Joint) iterator.next();
g.setPaint(randomColor().darker());
int newX = x + (int) (joint.getX() * (double) imageWidth);
int newY = y + (int) (joint.getY() * (double) imageHeight);
g.fillOval(newX, newY, 10, 10);
}
g.dispose();
}
private static Color randomColor() {
return new Color(RandomUtils.nextInt(255));
}
public static void drawBoundingBoxImage(Image img, DetectedObjects detection) {
img.drawBoundingBoxes(detection);
}
public static int getX(Image img, BoundingBox box) {
Rectangle rect = box.getBounds();
// 左上角x坐标
int x = (int) (rect.getX() * img.getWidth());
return x;
}
public static int getY(Image img, BoundingBox box) {
Rectangle rect = box.getBounds();
// 左上角y坐标
int y = (int) (rect.getY() * img.getHeight());
return y;
}
}

View File

@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="INFO">
<Appenders>
<Console name="console" target="SYSTEM_OUT">
<PatternLayout
pattern="[%-5level] - %msg%n"/>
</Console>
</Appenders>
<Loggers>
<Root level="info" additivity="false">
<AppenderRef ref="console"/>
</Root>
<Logger name="me.calvin" level="${sys:me.calvin.logging.level:-info}" additivity="false">
<AppenderRef ref="console"/>
</Logger>
</Loggers>
</Configuration>

View File

@ -1,58 +0,0 @@
### 官网:
[官网链接](https://www.aias.top/)
### 风格迁移(预定义4种)SDK
风格迁移可以把一张图片转换成另一种风格。本sdk预定义了4种画风
- 塞尚(Paul Cezanne, 18381906)
- 莫奈 (Claude monet, 18401926)
- 日本浮世绘
- 梵高 (Vincent Willem van Gogh, 1853~1890)
- 原图.
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/scenery.jpeg)
### 运行例子 - StyleTransferExample
运行成功后,命令行应该看到下面的信息:
```text
...
[INFO ] - Images generated and saved in folder - build/output/
```
#### 生成图片效果 - 塞尚风格
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/cezanne.png)
#### 生成图片效果 - 莫奈风格
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/monet.png)
#### 生成图片效果 - 日本浮世绘风格
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/ukiyoe.png)
#### 生成图片效果 - 梵高风格
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/vangogh.png)
### 开源算法
#### 找不到了,后面会有更好的替换
### 其它帮助信息
https://aias.top/guides.html
### Git地址
[Github链接](https://github.com/mymagicpower/AIAS)
[Gitee链接](https://gitee.com/mymagicpower/AIAS)
#### 帮助文档:
- https://aias.top/guides.html
- 1.性能优化常见问题:
- https://aias.top/AIAS/guides/performance.html
- 2.引擎配置包括CPUGPU在线自动加载及本地配置:
- https://aias.top/AIAS/guides/engine_config.html
- 3.模型加载方式(在线自动加载,及本地配置):
- https://aias.top/AIAS/guides/load_model.html
- 4.Windows环境常见问题:
- https://aias.top/AIAS/guides/windows.html

View File

@ -1,109 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one
~ or more contributor license agreements. See the NOTICE file
~ distributed with this work for additional information
~ regarding copyright ownership. The ASF licenses this file
~ to you under the Apache License, Version 2.0 (the
~ "License"); you may not use this file except in compliance
~ with the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing,
~ software distributed under the License is distributed on an
~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
~ KIND, either express or implied. See the License for the
~ specific language governing permissions and limitations
~ under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>calvin</groupId>
<artifactId>style-transfer-sdk</artifactId>
<version>0.17.0</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<djl.version>0.17.0</djl.version>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
<version>3.8.1</version>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.17.2</version>
</dependency>
<!-- 服务器端推理引擎 -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.18</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>6.8.1</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -1,72 +0,0 @@
package me.aias;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import me.aias.util.ImageUtils;
import me.aias.util.StyleTransfer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Paths;
/**
* 风格迁移
*
* @author Calvin
* @email 179209347@qq.com
* @website www.aias.top
*/
public final class StyleTransferExample {
private static final Logger logger = LoggerFactory.getLogger(StyleTransferExample.class);
private StyleTransferExample() {}
public static void main(String[] args) throws IOException, ModelException, TranslateException {
String imagePath = "src/test/resources/scenery.jpeg";
Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
StyleTransfer styleTransfer = new StyleTransfer();
// 梵高风格 (Vincent Willem van Gogh, 1853~1890)
StyleTransfer.Artist artist = StyleTransfer.Artist.VANGOGH;
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
Predictor<Image, Image> predictor = model.newPredictor()) {
Image output = predictor.predict(input);
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
}
// 塞尚风格 (Paul Cezanne, 18381906)
artist = StyleTransfer.Artist.CEZANNE;
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
Predictor<Image, Image> predictor = model.newPredictor()) {
Image output = predictor.predict(input);
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
}
// 莫奈风格 (Claude monet, 18401926)
artist = StyleTransfer.Artist.MONET;
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
Predictor<Image, Image> predictor = model.newPredictor()) {
Image output = predictor.predict(input);
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
}
// 日本浮世绘风格
artist = StyleTransfer.Artist.UKIYOE;
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
Predictor<Image, Image> predictor = model.newPredictor()) {
Image output = predictor.predict(input);
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
}
logger.info("Images generated and saved in folder - build/output/");
}
}

View File

@ -1,168 +0,0 @@
package me.aias.util;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.util.RandomUtils;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
public class ImageUtils {
public static Image bufferedImage2DJLImage(BufferedImage img) {
return ImageFactory.getInstance().fromImage(img);
}
public static void saveImage(BufferedImage img, String name, String path) {
Image newImage = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式自动适配
Path outputDir = Paths.get(path);
Path imagePath = outputDir.resolve(name);
// OpenJDK 不能保存 jpg 图片的 alpha channel
try {
newImage.save(Files.newOutputStream(imagePath), "png");
} catch (IOException e) {
e.printStackTrace();
}
}
public static void saveImage(Image img, String name, String path) {
Path outputDir = Paths.get(path);
Path imagePath = outputDir.resolve(name);
// OpenJDK 不能保存 jpg 图片的 alpha channel
try {
img.save(Files.newOutputStream(imagePath), "png");
} catch (IOException e) {
e.printStackTrace();
}
}
public static void saveBoundingBoxImage(
Image img, DetectedObjects detection, String name, String path) throws IOException {
// Make image copy with alpha channel because original image was jpg
img.drawBoundingBoxes(detection);
Path outputDir = Paths.get(path);
Files.createDirectories(outputDir);
Path imagePath = outputDir.resolve(name);
// OpenJDK can't save jpg with alpha channel
img.save(Files.newOutputStream(imagePath), "png");
}
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
// 将绘制图像转换为Graphics2D
Graphics2D g = (Graphics2D) image.getGraphics();
try {
g.setColor(new Color(246, 96, 0));
// 声明画笔属性 单位像素末端无修饰 折线处呈尖角
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
g.setStroke(bStroke);
g.drawRect(x, y, width, height);
} finally {
g.dispose();
}
}
public static void drawImageRect(
BufferedImage image, int x, int y, int width, int height, Color c) {
// 将绘制图像转换为Graphics2D
Graphics2D g = (Graphics2D) image.getGraphics();
try {
g.setColor(c);
// 声明画笔属性 单位像素末端无修饰 折线处呈尖角
BasicStroke bStroke = new BasicStroke(1, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
g.setStroke(bStroke);
g.drawRect(x, y, width, height);
} finally {
g.dispose();
}
}
public static void drawImageText(BufferedImage image, String text) {
Graphics graphics = image.getGraphics();
int fontSize = 100;
Font font = new Font("楷体", Font.PLAIN, fontSize);
try {
graphics.setFont(font);
graphics.setColor(new Color(246, 96, 0));
int strWidth = graphics.getFontMetrics().stringWidth(text);
graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30);
} finally {
graphics.dispose();
}
}
/** 返回外扩人脸 factor = 1, 100%, factor = 0.2, 20% */
public static Image getSubImage(Image img, BoundingBox box, float factor) {
Rectangle rect = box.getBounds();
// 左上角坐标
int x1 = (int) (rect.getX() * img.getWidth());
int y1 = (int) (rect.getY() * img.getHeight());
// 宽度高度
int w = (int) (rect.getWidth() * img.getWidth());
int h = (int) (rect.getHeight() * img.getHeight());
// 左上角坐标
int x2 = x1 + w;
int y2 = y1 + h;
// 外扩大100%防止对齐后人脸出现黑边
int new_x1 = Math.max((int) (x1 + x1 * factor / 2 - x2 * factor / 2), 0);
int new_x2 = Math.min((int) (x2 + x2 * factor / 2 - x1 * factor / 2), img.getWidth() - 1);
int new_y1 = Math.max((int) (y1 + y1 * factor / 2 - y2 * factor / 2), 0);
int new_y2 = Math.min((int) (y2 + y2 * factor / 2 - y1 * factor / 2), img.getHeight() - 1);
int new_w = new_x2 - new_x1;
int new_h = new_y2 - new_y1;
return img.getSubImage(new_x1, new_y1, new_w, new_h);
}
public static void drawJoints(Image img, Image subImg, int x, int y, Joints joints) {
BufferedImage image = (BufferedImage) img.getWrappedImage();
Graphics2D g = (Graphics2D) image.getGraphics();
int stroke = 2;
g.setStroke(new BasicStroke((float) stroke));
int imageWidth = subImg.getWidth();
int imageHeight = subImg.getHeight();
Iterator iterator = joints.getJoints().iterator();
while (iterator.hasNext()) {
Joints.Joint joint = (Joints.Joint) iterator.next();
g.setPaint(randomColor().darker());
int newX = x + (int) (joint.getX() * (double) imageWidth);
int newY = y + (int) (joint.getY() * (double) imageHeight);
g.fillOval(newX, newY, 10, 10);
}
g.dispose();
}
private static Color randomColor() {
return new Color(RandomUtils.nextInt(255));
}
public static void drawBoundingBoxImage(Image img, DetectedObjects detection) {
img.drawBoundingBoxes(detection);
}
public static int getX(Image img, BoundingBox box) {
Rectangle rect = box.getBounds();
// 左上角x坐标
int x = (int) (rect.getX() * img.getWidth());
return x;
}
public static int getY(Image img, BoundingBox box) {
Rectangle rect = box.getBounds();
// 左上角y坐标
int y = (int) (rect.getY() * img.getHeight());
return y;
}
}

View File

@ -1,37 +0,0 @@
package me.aias.util;
import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
public final class StyleTransfer {
public StyleTransfer() {}
public enum Artist {
CEZANNE, // 塞尚(Paul Cezanne, 18381906)
MONET, // 莫奈 (Claude monet, 18401926)
UKIYOE, // 日本浮世绘
VANGOGH // 梵高 (Vincent Willem van Gogh, 1853~1890)
}
public Criteria<Image, Image> criteria(Artist artist) {
String modelName = "style_" + artist.toString().toLowerCase() + ".zip";
String modelUrl =
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/gan_models/" + modelName;
Criteria<Image, Image> criteria =
Criteria.builder()
.setTypes(Image.class, Image.class)
.optEngine("PyTorch") // Use PyTorch engine
.optModelUrls(modelUrl)
.optProgress(new ProgressBar())
.optDevice(Device.cpu())
.optTranslatorFactory(new StyleTransferTranslatorFactory())
.build();
return criteria;
}
}

View File

@ -1,35 +0,0 @@
package me.aias.util;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
public class StyleTransferTranslator implements Translator<Image, Image> {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDArray image = switchFormat(input.toNDArray(ctx.getNDManager())).expandDims(0);
return new NDList(image.toType(DataType.FLOAT32, false));
}
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.get(0).addi(1).muli(128).toType(DataType.UINT8, false);
return ImageFactory.getInstance().fromNDArray(output.squeeze());
}
@Override
public Batchifier getBatchifier() {
return null;
}
private NDArray switchFormat(NDArray array) {
return NDArrays.stack(array.split(3, 2)).squeeze();
}
}

View File

@ -1,31 +0,0 @@
package me.aias.util;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
public class StyleTransferTranslatorFactory implements TranslatorFactory {
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return Collections.singleton(new Pair<>(Image.class, Image.class));
}
@Override
public Translator<?, ?> newInstance(
Class<?> input, Class<?> output, Model model, Map<String, ?> arguments)
throws TranslateException {
if (!isSupported(input, output)) {
throw new IllegalArgumentException("Unsupported input/output types.");
}
return new StyleTransferTranslator();
}
}

View File

@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="INFO">
<Appenders>
<Console name="console" target="SYSTEM_OUT">
<PatternLayout
pattern="[%-5level] - %msg%n"/>
</Console>
</Appenders>
<Loggers>
<Root level="info" additivity="false">
<AppenderRef ref="console"/>
</Root>
<Logger name="me.calvin" level="${sys:me.calvin.logging.level:-info}" additivity="false">
<AppenderRef ref="console"/>
</Logger>
</Loggers>
</Configuration>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 47 KiB

View File

@ -1,43 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
<component name="CheckStyle-IDEA-Module">
<option name="configuration">
<map />
</option>
</component>
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
<output url="file://$MODULE_DIR$/target/classes" />
<output-test url="file://$MODULE_DIR$/target/test-classes" />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/src/test/resources" type="java-test-resource" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="aias-style-transfer-lib-0.1.0" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.17.2" level="project" />
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.25" level="project" />
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.17.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.17.2" level="project" />
<orderEntry type="library" name="Maven: ai.djl:api:0.17.0" level="project" />
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.10.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
<orderEntry type="library" name="Maven: ai.djl:basicdataset:0.17.0" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-csv:1.9.0" level="project" />
<orderEntry type="library" name="Maven: ai.djl:model-zoo:0.17.0" level="project" />
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-engine:0.17.0" level="project" />
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-model-zoo:0.17.0" level="project" />
<orderEntry type="library" scope="PROVIDED" name="Maven: org.projectlombok:lombok:1.18.18" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.testng:testng:6.8.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.10" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.beanshell:bsh:2.0b4" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: com.beust:jcommander:1.27" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.yaml:snakeyaml:1.6" level="project" />
</component>
</module>

Binary file not shown.

After

Width:  |  Height:  |  Size: 849 KiB