remove biggan_sdk and style_transfer_sdk
@ -1,137 +0,0 @@
|
||||
### 官网:
|
||||
[官网链接](https://www.aias.top/)
|
||||
|
||||
|
||||
### BIGGAN 图像自动生成SDK
|
||||
能够自动生成1000种类别(支持imagenet数据集分类)的图片。
|
||||
|
||||
### 支持分类如下:
|
||||
- tench, Tinca tinca
|
||||
- goldfish, Carassius auratus
|
||||
- great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
|
||||
- tiger shark, Galeocerdo cuvieri
|
||||
- hammerhead, hammerhead shark
|
||||
- electric ray, crampfish, numbfish, torpedo
|
||||
- stingray
|
||||
- cock
|
||||
- hen
|
||||
- ostrich, Struthio camelus
|
||||
- brambling, Fringilla montifringilla
|
||||
- goldfinch, Carduelis carduelis
|
||||
- house finch, linnet, Carpodacus mexicanus
|
||||
- junco, snowbird
|
||||
- indigo bunting, indigo finch, indigo bird, Passerina cyanea
|
||||
- robin, American robin, Turdus migratorius
|
||||
- ...
|
||||
|
||||
[点击下载](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/classification_imagenet_sdk/synset.txt)
|
||||
|
||||
### SDK包含两个分类器:
|
||||
size 支持 128, 256, 512三种图片尺寸
|
||||
如:size = 512;
|
||||
imageClass 支持imagenet类别0~999
|
||||
如:imageClass = 156;
|
||||
|
||||
### 运行例子 - BigGAN
|
||||
- 测试图片类别11,图片尺寸:512X512
|
||||
![img1](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image11.png)
|
||||
|
||||
- 测试图片类别156,图片尺寸:512X512
|
||||
![img2](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image156.png)
|
||||
|
||||
- 测试图片类别821,图片尺寸:512X512
|
||||
![img3](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/biggan_sdk/image821.png)
|
||||
|
||||
运行成功后,命令行应该看到下面的信息:
|
||||
```text
|
||||
...
|
||||
[INFO ] - Number of inter-op threads is 4
|
||||
[INFO ] - Number of intra-op threads is 4
|
||||
[INFO ] - Generated image has been saved in: build/output/
|
||||
```
|
||||
|
||||
### 开源算法
|
||||
#### 1. sdk使用的开源算法
|
||||
- [BigGAN-Generator-Pretrained-Pytorch](https://github.com/ivclab/BigGAN-Generator-Pretrained-Pytorch)
|
||||
- [预训练模型 biggan-128](https://tfhub.dev/deepmind/biggan-128/2)
|
||||
- [预训练模型 biggan-256](https://tfhub.dev/deepmind/biggan-256/2)
|
||||
- [预训练模型 biggan-512](https://tfhub.dev/deepmind/biggan-512/2)
|
||||
|
||||
|
||||
#### 2. 模型如何导出 ?
|
||||
- [how_to_convert_your_model_to_torchscript](http://docs.djl.ai/docs/pytorch/how_to_convert_your_model_to_torchscript.html)
|
||||
|
||||
- 导出模型
|
||||
```text
|
||||
from src.biggan import BigGAN128
|
||||
from src.biggan import BigGAN256
|
||||
from src.biggan import BigGAN512
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--truncation', type=float, default=0.4)
|
||||
parser.add_argument('-s', '--size', type=int, choices=[128, 256, 512], default=512)
|
||||
parser.add_argument('-c', '--class_label', type=int, choices=range(0, 1000), default=156)
|
||||
parser.add_argument('-w', '--pretrained_weight', type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
truncation = torch.clamp(torch.tensor(args.truncation), min=0.02+1e-4, max=1.0-1e-4).float()
|
||||
c = torch.tensor((args.class_label,)).long()
|
||||
|
||||
if args.size == 128:
|
||||
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 120))).float()
|
||||
biggan = BigGAN128()
|
||||
elif args.size == 256:
|
||||
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 140))).float()
|
||||
biggan = BigGAN256()
|
||||
elif args.size == 512:
|
||||
z = truncation * torch.as_tensor(truncnorm.rvs(-2.0, 2.0, size=(1, 128))).float()
|
||||
biggan = BigGAN512()
|
||||
|
||||
biggan.load_state_dict(torch.load(args.pretrained_weight))
|
||||
biggan.eval()
|
||||
|
||||
#Generate model for DJL
|
||||
listSample = [z, c, torch.tensor(0.2)]
|
||||
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
|
||||
traced_script_module = torch.jit.trace(biggan, listSample)
|
||||
# sm = torch.jit.script(tra)
|
||||
# Save the TorchScript model
|
||||
traced_script_module.save("traced_model.pt")
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
img = biggan(z, c, truncation)
|
||||
|
||||
|
||||
img = 0.5 * (img.data + 1)
|
||||
pil = torchvision.transforms.ToPILImage()(img.squeeze())
|
||||
pil.show()
|
||||
```
|
||||
|
||||
|
||||
### 其它帮助信息
|
||||
https://aias.top/guides.html
|
||||
|
||||
### Git地址:
|
||||
[Github链接](https://github.com/mymagicpower/AIAS)
|
||||
[Gitee链接](https://gitee.com/mymagicpower/AIAS)
|
||||
|
||||
|
||||
#### 帮助文档:
|
||||
- https://aias.top/guides.html
|
||||
- 1.性能优化常见问题:
|
||||
- https://aias.top/AIAS/guides/performance.html
|
||||
- 2.引擎配置(包括CPU,GPU在线自动加载,及本地配置):
|
||||
- https://aias.top/AIAS/guides/engine_config.html
|
||||
- 3.模型加载方式(在线自动加载,及本地配置):
|
||||
- https://aias.top/AIAS/guides/load_model.html
|
||||
- 4.Windows环境常见问题:
|
||||
- https://aias.top/AIAS/guides/windows.html
|
@ -1,42 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
|
||||
<component name="CheckStyle-IDEA-Module">
|
||||
<option name="configuration">
|
||||
<map />
|
||||
</option>
|
||||
</component>
|
||||
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
|
||||
<output url="file://$MODULE_DIR$/target/classes" />
|
||||
<output-test url="file://$MODULE_DIR$/target/test-classes" />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/test/resources" type="java-test-resource" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="aias-classification-imagenet-lib-0.1.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
|
||||
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.17.2" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.25" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.17.2" level="project" />
|
||||
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.17.2" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:api:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.10.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:basicdataset:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.commons:commons-csv:1.9.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:model-zoo:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-engine:0.17.0" level="project" />
|
||||
<orderEntry type="library" scope="PROVIDED" name="Maven: org.projectlombok:lombok:1.18.18" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.testng:testng:6.8.1" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.10" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.1" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.beanshell:bsh:2.0b4" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: com.beust:jcommander:1.27" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.yaml:snakeyaml:1.6" level="project" />
|
||||
</component>
|
||||
</module>
|
Before Width: | Height: | Size: 585 KiB |
Before Width: | Height: | Size: 653 KiB |
Before Width: | Height: | Size: 543 KiB |
Before Width: | Height: | Size: 78 KiB |
Before Width: | Height: | Size: 83 KiB |
Before Width: | Height: | Size: 99 KiB |
@ -1,103 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one
|
||||
~ or more contributor license agreements. See the NOTICE file
|
||||
~ distributed with this work for additional information
|
||||
~ regarding copyright ownership. The ASF licenses this file
|
||||
~ to you under the Apache License, Version 2.0 (the
|
||||
~ "License"); you may not use this file except in compliance
|
||||
~ with the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing,
|
||||
~ software distributed under the License is distributed on an
|
||||
~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
~ KIND, either express or implied. See the License for the
|
||||
~ specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>aias</groupId>
|
||||
<artifactId>biggan-sdk</artifactId>
|
||||
<version>0.17.0</version>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<djl.version>0.17.0</djl.version>
|
||||
</properties>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<configuration>
|
||||
<source>8</source>
|
||||
<target>8</target>
|
||||
</configuration>
|
||||
<version>3.8.1</version>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-cli</groupId>
|
||||
<artifactId>commons-cli</artifactId>
|
||||
<version>1.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-slf4j-impl</artifactId>
|
||||
<version>2.17.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 服务器端推理引擎 -->
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>api</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>basicdataset</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>model-zoo</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<!-- Pytorch -->
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-engine</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.18.18</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.8.1</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
@ -1,72 +0,0 @@
|
||||
package me.aias;
|
||||
|
||||
import ai.djl.ModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelZoo;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.training.util.ProgressBar;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import me.aias.utils.BigGANTranslator;
|
||||
import me.aias.utils.ImageUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* An example of generation using BigGAN.
|
||||
*/
|
||||
public final class BigGAN {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BigGAN.class);
|
||||
|
||||
public BigGAN() {
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||
// size 支持 128, 256, 512
|
||||
int size = 512;
|
||||
// imageClass 支持imagenet类别1~1000
|
||||
long imageClass = 156;
|
||||
|
||||
Criteria<Long, Image> criteria = new BigGAN().generate(size, 0.4f);
|
||||
Image image = null;
|
||||
try (ZooModel<Long, Image> model = ModelZoo.loadModel(criteria);
|
||||
Predictor<Long, Image> generator = model.newPredictor()) {
|
||||
image = generator.predict(imageClass);
|
||||
}
|
||||
|
||||
ImageUtils.saveImage(image, "image" + imageClass + ".png", "build/output/");
|
||||
logger.info("Generated image has been saved in: {}", "build/output/");
|
||||
}
|
||||
|
||||
public Criteria<Long, Image> generate(int size, float truncation) {
|
||||
|
||||
String url = null;
|
||||
if (size == 128) {
|
||||
size = 120;
|
||||
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan128.zip";
|
||||
} else if (size == 256) {
|
||||
size = 140;
|
||||
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan256.zip";
|
||||
} else if (size == 512) {
|
||||
size = 128;
|
||||
url = "https://aias-home.oss-cn-beijing.aliyuncs.com/models/biggan512.zip";
|
||||
}
|
||||
|
||||
BigGANTranslator translator = new BigGANTranslator(size, truncation);
|
||||
Criteria<Long, Image> criteria =
|
||||
Criteria.builder()
|
||||
.optEngine("PyTorch") // Use PyTorch engine
|
||||
.setTypes(Long.class, Image.class)
|
||||
.optModelUrls(url)
|
||||
// .optModelUrls("/Users/calvin/BigGAN-Generator-Pretrained-Pytorch/")
|
||||
// .optModelName("traced_biggan512_model")
|
||||
.optTranslator(translator)
|
||||
.optProgress(new ProgressBar())
|
||||
.build();
|
||||
return criteria;
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
package me.aias.utils;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
public final class BigGANTranslator implements Translator<Long, Image> {
|
||||
|
||||
private float truncation;
|
||||
private int size;
|
||||
|
||||
public BigGANTranslator(int size, float truncation) {
|
||||
this.size = size;
|
||||
this.truncation = truncation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image processOutput(TranslatorContext ctx, NDList list) {
|
||||
NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false);
|
||||
|
||||
Image image = ImageFactory.getInstance().fromNDArray(output.get(0));
|
||||
|
||||
return image;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Long imageClass) throws Exception {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
|
||||
NDArray seed =
|
||||
manager.truncatedNormal(new Shape(1, this.size)).clip(-2.0, 2.0).muli(truncation);
|
||||
|
||||
return new NDList(seed, manager.create(imageClass).expandDims(0), manager.create(truncation));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null;
|
||||
}
|
||||
}
|
@ -1,168 +0,0 @@
|
||||
package me.aias.utils;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.modality.cv.output.BoundingBox;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Joints;
|
||||
import ai.djl.modality.cv.output.Rectangle;
|
||||
import ai.djl.util.RandomUtils;
|
||||
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Iterator;
|
||||
|
||||
public class ImageUtils {
|
||||
|
||||
public static Image bufferedImage2DJLImage(BufferedImage img) {
|
||||
return ImageFactory.getInstance().fromImage(img);
|
||||
}
|
||||
|
||||
public static void saveImage(BufferedImage img, String name, String path) {
|
||||
Image newImage = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式,自动适配
|
||||
Path outputDir = Paths.get(path);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||
try {
|
||||
newImage.save(Files.newOutputStream(imagePath), "png");
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void saveImage(Image img, String name, String path) {
|
||||
Path outputDir = Paths.get(path);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||
try {
|
||||
img.save(Files.newOutputStream(imagePath), "png");
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void saveBoundingBoxImage(
|
||||
Image img, DetectedObjects detection, String name, String path) throws IOException {
|
||||
// Make image copy with alpha channel because original image was jpg
|
||||
img.drawBoundingBoxes(detection);
|
||||
Path outputDir = Paths.get(path);
|
||||
Files.createDirectories(outputDir);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK can't save jpg with alpha channel
|
||||
img.save(Files.newOutputStream(imagePath), "png");
|
||||
}
|
||||
|
||||
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
|
||||
// 将绘制图像转换为Graphics2D
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
try {
|
||||
g.setColor(new Color(246, 96, 0));
|
||||
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||
g.setStroke(bStroke);
|
||||
g.drawRect(x, y, width, height);
|
||||
|
||||
} finally {
|
||||
g.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public static void drawImageRect(
|
||||
BufferedImage image, int x, int y, int width, int height, Color c) {
|
||||
// 将绘制图像转换为Graphics2D
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
try {
|
||||
g.setColor(c);
|
||||
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||
BasicStroke bStroke = new BasicStroke(1, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||
g.setStroke(bStroke);
|
||||
g.drawRect(x, y, width, height);
|
||||
|
||||
} finally {
|
||||
g.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public static void drawImageText(BufferedImage image, String text) {
|
||||
Graphics graphics = image.getGraphics();
|
||||
int fontSize = 100;
|
||||
Font font = new Font("楷体", Font.PLAIN, fontSize);
|
||||
try {
|
||||
graphics.setFont(font);
|
||||
graphics.setColor(new Color(246, 96, 0));
|
||||
int strWidth = graphics.getFontMetrics().stringWidth(text);
|
||||
graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30);
|
||||
} finally {
|
||||
graphics.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/** 返回外扩人脸 factor = 1, 100%, factor = 0.2, 20% */
|
||||
public static Image getSubImage(Image img, BoundingBox box, float factor) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角坐标
|
||||
int x1 = (int) (rect.getX() * img.getWidth());
|
||||
int y1 = (int) (rect.getY() * img.getHeight());
|
||||
// 宽度,高度
|
||||
int w = (int) (rect.getWidth() * img.getWidth());
|
||||
int h = (int) (rect.getHeight() * img.getHeight());
|
||||
// 左上角坐标
|
||||
int x2 = x1 + w;
|
||||
int y2 = y1 + h;
|
||||
|
||||
// 外扩大100%,防止对齐后人脸出现黑边
|
||||
int new_x1 = Math.max((int) (x1 + x1 * factor / 2 - x2 * factor / 2), 0);
|
||||
int new_x2 = Math.min((int) (x2 + x2 * factor / 2 - x1 * factor / 2), img.getWidth() - 1);
|
||||
int new_y1 = Math.max((int) (y1 + y1 * factor / 2 - y2 * factor / 2), 0);
|
||||
int new_y2 = Math.min((int) (y2 + y2 * factor / 2 - y1 * factor / 2), img.getHeight() - 1);
|
||||
int new_w = new_x2 - new_x1;
|
||||
int new_h = new_y2 - new_y1;
|
||||
return img.getSubImage(new_x1, new_y1, new_w, new_h);
|
||||
}
|
||||
|
||||
public static void drawJoints(Image img, Image subImg, int x, int y, Joints joints) {
|
||||
BufferedImage image = (BufferedImage) img.getWrappedImage();
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
int stroke = 2;
|
||||
g.setStroke(new BasicStroke((float) stroke));
|
||||
int imageWidth = subImg.getWidth();
|
||||
int imageHeight = subImg.getHeight();
|
||||
Iterator iterator = joints.getJoints().iterator();
|
||||
|
||||
while (iterator.hasNext()) {
|
||||
Joints.Joint joint = (Joints.Joint) iterator.next();
|
||||
g.setPaint(randomColor().darker());
|
||||
int newX = x + (int) (joint.getX() * (double) imageWidth);
|
||||
int newY = y + (int) (joint.getY() * (double) imageHeight);
|
||||
g.fillOval(newX, newY, 10, 10);
|
||||
}
|
||||
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
private static Color randomColor() {
|
||||
return new Color(RandomUtils.nextInt(255));
|
||||
}
|
||||
|
||||
public static void drawBoundingBoxImage(Image img, DetectedObjects detection) {
|
||||
img.drawBoundingBoxes(detection);
|
||||
}
|
||||
|
||||
public static int getX(Image img, BoundingBox box) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角x坐标
|
||||
int x = (int) (rect.getX() * img.getWidth());
|
||||
return x;
|
||||
}
|
||||
|
||||
public static int getY(Image img, BoundingBox box) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角y坐标
|
||||
int y = (int) (rect.getY() * img.getHeight());
|
||||
return y;
|
||||
}
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Configuration status="INFO">
|
||||
<Appenders>
|
||||
<Console name="console" target="SYSTEM_OUT">
|
||||
<PatternLayout
|
||||
pattern="[%-5level] - %msg%n"/>
|
||||
</Console>
|
||||
</Appenders>
|
||||
<Loggers>
|
||||
<Root level="info" additivity="false">
|
||||
<AppenderRef ref="console"/>
|
||||
</Root>
|
||||
<Logger name="me.calvin" level="${sys:me.calvin.logging.level:-info}" additivity="false">
|
||||
<AppenderRef ref="console"/>
|
||||
</Logger>
|
||||
</Loggers>
|
||||
</Configuration>
|
@ -1,58 +0,0 @@
|
||||
### 官网:
|
||||
[官网链接](https://www.aias.top/)
|
||||
|
||||
### 风格迁移(预定义4种)SDK
|
||||
风格迁移可以把一张图片转换成另一种风格。本sdk预定义了4种画风:
|
||||
- 塞尚(Paul Cezanne, 1838~1906)
|
||||
- 莫奈 (Claude monet, 1840~1926)
|
||||
- 日本浮世绘
|
||||
- 梵高 (Vincent Willem van Gogh, 1853~1890)
|
||||
|
||||
|
||||
- 原图.
|
||||
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/scenery.jpeg)
|
||||
|
||||
|
||||
### 运行例子 - StyleTransferExample
|
||||
运行成功后,命令行应该看到下面的信息:
|
||||
```text
|
||||
...
|
||||
[INFO ] - Images generated and saved in folder - build/output/
|
||||
|
||||
```
|
||||
|
||||
#### 生成图片效果 - 塞尚风格
|
||||
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/cezanne.png)
|
||||
|
||||
#### 生成图片效果 - 莫奈风格
|
||||
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/monet.png)
|
||||
|
||||
#### 生成图片效果 - 日本浮世绘风格
|
||||
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/ukiyoe.png)
|
||||
|
||||
#### 生成图片效果 - 梵高风格
|
||||
![img](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/gan_sdks/vangogh.png)
|
||||
|
||||
|
||||
### 开源算法
|
||||
#### 找不到了,后面会有更好的替换
|
||||
|
||||
### 其它帮助信息
|
||||
https://aias.top/guides.html
|
||||
|
||||
|
||||
### Git地址:
|
||||
[Github链接](https://github.com/mymagicpower/AIAS)
|
||||
[Gitee链接](https://gitee.com/mymagicpower/AIAS)
|
||||
|
||||
|
||||
#### 帮助文档:
|
||||
- https://aias.top/guides.html
|
||||
- 1.性能优化常见问题:
|
||||
- https://aias.top/AIAS/guides/performance.html
|
||||
- 2.引擎配置(包括CPU,GPU在线自动加载,及本地配置):
|
||||
- https://aias.top/AIAS/guides/engine_config.html
|
||||
- 3.模型加载方式(在线自动加载,及本地配置):
|
||||
- https://aias.top/AIAS/guides/load_model.html
|
||||
- 4.Windows环境常见问题:
|
||||
- https://aias.top/AIAS/guides/windows.html
|
@ -1,109 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one
|
||||
~ or more contributor license agreements. See the NOTICE file
|
||||
~ distributed with this work for additional information
|
||||
~ regarding copyright ownership. The ASF licenses this file
|
||||
~ to you under the Apache License, Version 2.0 (the
|
||||
~ "License"); you may not use this file except in compliance
|
||||
~ with the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing,
|
||||
~ software distributed under the License is distributed on an
|
||||
~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
~ KIND, either express or implied. See the License for the
|
||||
~ specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>calvin</groupId>
|
||||
<artifactId>style-transfer-sdk</artifactId>
|
||||
<version>0.17.0</version>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<djl.version>0.17.0</djl.version>
|
||||
</properties>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<configuration>
|
||||
<source>8</source>
|
||||
<target>8</target>
|
||||
</configuration>
|
||||
<version>3.8.1</version>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-cli</groupId>
|
||||
<artifactId>commons-cli</artifactId>
|
||||
<version>1.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-slf4j-impl</artifactId>
|
||||
<version>2.17.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 服务器端推理引擎 -->
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>api</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>basicdataset</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>model-zoo</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<!-- Pytorch -->
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-engine</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-model-zoo</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.18.18</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>6.8.1</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
@ -1,72 +0,0 @@
|
||||
package me.aias;
|
||||
|
||||
import ai.djl.ModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.repository.zoo.ModelZoo;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import me.aias.util.ImageUtils;
|
||||
import me.aias.util.StyleTransfer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Paths;
|
||||
|
||||
/**
|
||||
* 风格迁移
|
||||
*
|
||||
* @author Calvin
|
||||
* @email 179209347@qq.com
|
||||
* @website www.aias.top
|
||||
*/
|
||||
|
||||
public final class StyleTransferExample {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(StyleTransferExample.class);
|
||||
|
||||
private StyleTransferExample() {}
|
||||
|
||||
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||
String imagePath = "src/test/resources/scenery.jpeg";
|
||||
Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
|
||||
StyleTransfer styleTransfer = new StyleTransfer();
|
||||
|
||||
// 梵高风格 (Vincent Willem van Gogh, 1853~1890)
|
||||
StyleTransfer.Artist artist = StyleTransfer.Artist.VANGOGH;
|
||||
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
|
||||
Predictor<Image, Image> predictor = model.newPredictor()) {
|
||||
Image output = predictor.predict(input);
|
||||
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
|
||||
}
|
||||
|
||||
// 塞尚风格 (Paul Cezanne, 1838~1906)
|
||||
artist = StyleTransfer.Artist.CEZANNE;
|
||||
|
||||
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
|
||||
Predictor<Image, Image> predictor = model.newPredictor()) {
|
||||
Image output = predictor.predict(input);
|
||||
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
|
||||
}
|
||||
|
||||
// 莫奈风格 (Claude monet, 1840~1926)
|
||||
artist = StyleTransfer.Artist.MONET;
|
||||
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
|
||||
Predictor<Image, Image> predictor = model.newPredictor()) {
|
||||
Image output = predictor.predict(input);
|
||||
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
|
||||
}
|
||||
|
||||
// 日本浮世绘风格
|
||||
artist = StyleTransfer.Artist.UKIYOE;
|
||||
try (ZooModel<Image, Image> model = ModelZoo.loadModel(styleTransfer.criteria(artist));
|
||||
Predictor<Image, Image> predictor = model.newPredictor()) {
|
||||
Image output = predictor.predict(input);
|
||||
ImageUtils.saveImage(output, artist.toString().toLowerCase()+ ".png", "build/output/");
|
||||
}
|
||||
|
||||
logger.info("Images generated and saved in folder - build/output/");
|
||||
}
|
||||
}
|
@ -1,168 +0,0 @@
|
||||
package me.aias.util;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.modality.cv.output.BoundingBox;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Joints;
|
||||
import ai.djl.modality.cv.output.Rectangle;
|
||||
import ai.djl.util.RandomUtils;
|
||||
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Iterator;
|
||||
|
||||
public class ImageUtils {
|
||||
|
||||
public static Image bufferedImage2DJLImage(BufferedImage img) {
|
||||
return ImageFactory.getInstance().fromImage(img);
|
||||
}
|
||||
|
||||
public static void saveImage(BufferedImage img, String name, String path) {
|
||||
Image newImage = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式,自动适配
|
||||
Path outputDir = Paths.get(path);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||
try {
|
||||
newImage.save(Files.newOutputStream(imagePath), "png");
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void saveImage(Image img, String name, String path) {
|
||||
Path outputDir = Paths.get(path);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||
try {
|
||||
img.save(Files.newOutputStream(imagePath), "png");
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void saveBoundingBoxImage(
|
||||
Image img, DetectedObjects detection, String name, String path) throws IOException {
|
||||
// Make image copy with alpha channel because original image was jpg
|
||||
img.drawBoundingBoxes(detection);
|
||||
Path outputDir = Paths.get(path);
|
||||
Files.createDirectories(outputDir);
|
||||
Path imagePath = outputDir.resolve(name);
|
||||
// OpenJDK can't save jpg with alpha channel
|
||||
img.save(Files.newOutputStream(imagePath), "png");
|
||||
}
|
||||
|
||||
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
|
||||
// 将绘制图像转换为Graphics2D
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
try {
|
||||
g.setColor(new Color(246, 96, 0));
|
||||
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||
g.setStroke(bStroke);
|
||||
g.drawRect(x, y, width, height);
|
||||
|
||||
} finally {
|
||||
g.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public static void drawImageRect(
|
||||
BufferedImage image, int x, int y, int width, int height, Color c) {
|
||||
// 将绘制图像转换为Graphics2D
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
try {
|
||||
g.setColor(c);
|
||||
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||
BasicStroke bStroke = new BasicStroke(1, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||
g.setStroke(bStroke);
|
||||
g.drawRect(x, y, width, height);
|
||||
|
||||
} finally {
|
||||
g.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public static void drawImageText(BufferedImage image, String text) {
|
||||
Graphics graphics = image.getGraphics();
|
||||
int fontSize = 100;
|
||||
Font font = new Font("楷体", Font.PLAIN, fontSize);
|
||||
try {
|
||||
graphics.setFont(font);
|
||||
graphics.setColor(new Color(246, 96, 0));
|
||||
int strWidth = graphics.getFontMetrics().stringWidth(text);
|
||||
graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30);
|
||||
} finally {
|
||||
graphics.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/** 返回外扩人脸 factor = 1, 100%, factor = 0.2, 20% */
|
||||
public static Image getSubImage(Image img, BoundingBox box, float factor) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角坐标
|
||||
int x1 = (int) (rect.getX() * img.getWidth());
|
||||
int y1 = (int) (rect.getY() * img.getHeight());
|
||||
// 宽度,高度
|
||||
int w = (int) (rect.getWidth() * img.getWidth());
|
||||
int h = (int) (rect.getHeight() * img.getHeight());
|
||||
// 左上角坐标
|
||||
int x2 = x1 + w;
|
||||
int y2 = y1 + h;
|
||||
|
||||
// 外扩大100%,防止对齐后人脸出现黑边
|
||||
int new_x1 = Math.max((int) (x1 + x1 * factor / 2 - x2 * factor / 2), 0);
|
||||
int new_x2 = Math.min((int) (x2 + x2 * factor / 2 - x1 * factor / 2), img.getWidth() - 1);
|
||||
int new_y1 = Math.max((int) (y1 + y1 * factor / 2 - y2 * factor / 2), 0);
|
||||
int new_y2 = Math.min((int) (y2 + y2 * factor / 2 - y1 * factor / 2), img.getHeight() - 1);
|
||||
int new_w = new_x2 - new_x1;
|
||||
int new_h = new_y2 - new_y1;
|
||||
return img.getSubImage(new_x1, new_y1, new_w, new_h);
|
||||
}
|
||||
|
||||
public static void drawJoints(Image img, Image subImg, int x, int y, Joints joints) {
|
||||
BufferedImage image = (BufferedImage) img.getWrappedImage();
|
||||
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||
int stroke = 2;
|
||||
g.setStroke(new BasicStroke((float) stroke));
|
||||
int imageWidth = subImg.getWidth();
|
||||
int imageHeight = subImg.getHeight();
|
||||
Iterator iterator = joints.getJoints().iterator();
|
||||
|
||||
while (iterator.hasNext()) {
|
||||
Joints.Joint joint = (Joints.Joint) iterator.next();
|
||||
g.setPaint(randomColor().darker());
|
||||
int newX = x + (int) (joint.getX() * (double) imageWidth);
|
||||
int newY = y + (int) (joint.getY() * (double) imageHeight);
|
||||
g.fillOval(newX, newY, 10, 10);
|
||||
}
|
||||
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
private static Color randomColor() {
|
||||
return new Color(RandomUtils.nextInt(255));
|
||||
}
|
||||
|
||||
public static void drawBoundingBoxImage(Image img, DetectedObjects detection) {
|
||||
img.drawBoundingBoxes(detection);
|
||||
}
|
||||
|
||||
public static int getX(Image img, BoundingBox box) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角x坐标
|
||||
int x = (int) (rect.getX() * img.getWidth());
|
||||
return x;
|
||||
}
|
||||
|
||||
public static int getY(Image img, BoundingBox box) {
|
||||
Rectangle rect = box.getBounds();
|
||||
// 左上角y坐标
|
||||
int y = (int) (rect.getY() * img.getHeight());
|
||||
return y;
|
||||
}
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
package me.aias.util;
|
||||
|
||||
import ai.djl.Device;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.training.util.ProgressBar;
|
||||
|
||||
public final class StyleTransfer {
|
||||
|
||||
public StyleTransfer() {}
|
||||
|
||||
public enum Artist {
|
||||
CEZANNE, // 塞尚(Paul Cezanne, 1838~1906)
|
||||
MONET, // 莫奈 (Claude monet, 1840~1926)
|
||||
UKIYOE, // 日本浮世绘
|
||||
VANGOGH // 梵高 (Vincent Willem van Gogh, 1853~1890)
|
||||
}
|
||||
|
||||
public Criteria<Image, Image> criteria(Artist artist) {
|
||||
|
||||
String modelName = "style_" + artist.toString().toLowerCase() + ".zip";
|
||||
String modelUrl =
|
||||
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/gan_models/" + modelName;
|
||||
|
||||
Criteria<Image, Image> criteria =
|
||||
Criteria.builder()
|
||||
.setTypes(Image.class, Image.class)
|
||||
.optEngine("PyTorch") // Use PyTorch engine
|
||||
.optModelUrls(modelUrl)
|
||||
.optProgress(new ProgressBar())
|
||||
.optDevice(Device.cpu())
|
||||
.optTranslatorFactory(new StyleTransferTranslatorFactory())
|
||||
.build();
|
||||
|
||||
return criteria;
|
||||
}
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
package me.aias.util;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDArrays;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
public class StyleTransferTranslator implements Translator<Image, Image> {
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDArray image = switchFormat(input.toNDArray(ctx.getNDManager())).expandDims(0);
|
||||
return new NDList(image.toType(DataType.FLOAT32, false));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image processOutput(TranslatorContext ctx, NDList list) {
|
||||
NDArray output = list.get(0).addi(1).muli(128).toType(DataType.UINT8, false);
|
||||
return ImageFactory.getInstance().fromNDArray(output.squeeze());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null;
|
||||
}
|
||||
|
||||
private NDArray switchFormat(NDArray array) {
|
||||
return NDArrays.stack(array.split(3, 2)).squeeze();
|
||||
}
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
package me.aias.util;
|
||||
|
||||
import ai.djl.Model;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorFactory;
|
||||
import ai.djl.util.Pair;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class StyleTransferTranslatorFactory implements TranslatorFactory {
|
||||
|
||||
@Override
|
||||
public Set<Pair<Type, Type>> getSupportedTypes() {
|
||||
return Collections.singleton(new Pair<>(Image.class, Image.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Translator<?, ?> newInstance(
|
||||
Class<?> input, Class<?> output, Model model, Map<String, ?> arguments)
|
||||
throws TranslateException {
|
||||
if (!isSupported(input, output)) {
|
||||
throw new IllegalArgumentException("Unsupported input/output types.");
|
||||
}
|
||||
return new StyleTransferTranslator();
|
||||
}
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Configuration status="INFO">
|
||||
<Appenders>
|
||||
<Console name="console" target="SYSTEM_OUT">
|
||||
<PatternLayout
|
||||
pattern="[%-5level] - %msg%n"/>
|
||||
</Console>
|
||||
</Appenders>
|
||||
<Loggers>
|
||||
<Root level="info" additivity="false">
|
||||
<AppenderRef ref="console"/>
|
||||
</Root>
|
||||
<Logger name="me.calvin" level="${sys:me.calvin.logging.level:-info}" additivity="false">
|
||||
<AppenderRef ref="console"/>
|
||||
</Logger>
|
||||
</Loggers>
|
||||
</Configuration>
|
Before Width: | Height: | Size: 47 KiB |
@ -1,43 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
|
||||
<component name="CheckStyle-IDEA-Module">
|
||||
<option name="configuration">
|
||||
<map />
|
||||
</option>
|
||||
</component>
|
||||
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
|
||||
<output url="file://$MODULE_DIR$/target/classes" />
|
||||
<output-test url="file://$MODULE_DIR$/target/test-classes" />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/test/resources" type="java-test-resource" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="library" name="aias-style-transfer-lib-0.1.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
|
||||
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.17.2" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.25" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.17.2" level="project" />
|
||||
<orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.17.2" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:api:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.10.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:basicdataset:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: org.apache.commons:commons-csv:1.9.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl:model-zoo:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-engine:0.17.0" level="project" />
|
||||
<orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-model-zoo:0.17.0" level="project" />
|
||||
<orderEntry type="library" scope="PROVIDED" name="Maven: org.projectlombok:lombok:1.18.18" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.testng:testng:6.8.1" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.10" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.1" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.beanshell:bsh:2.0b4" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: com.beust:jcommander:1.27" level="project" />
|
||||
<orderEntry type="library" scope="TEST" name="Maven: org.yaml:snakeyaml:1.6" level="project" />
|
||||
</component>
|
||||
</module>
|
After Width: | Height: | Size: 849 KiB |