This commit is contained in:
parent b1c865da4b
commit 626ed23697

.gitignore (vendored)
@@ -2,3 +2,5 @@
3_audio_sdks/ndarray_audio_sdk/target/
3_audio_sdks/asr_long_audio_sdk/.idea/
3_audio_sdks/asr_long_audio_sdk/models/
1_image_sdks/text_recognition/ocr_sdk/.idea/
1_image_sdks/text_recognition/.idea/
@@ -1,75 +0,0 @@
### Official website:
[Official website](http://www.aias.top/)


### Download the models and place them in the models directory
- Link: https://pan.baidu.com/s/19NaYJ55FiqFDL8_NUOfCnA?pwd=xff8

### Face Feature Extraction and Comparison SDK
#### Face recognition
In the broad sense, face recognition covers the whole set of technologies needed to build a face recognition system, including face image capture, face localization, preprocessing, identity verification, and identity search.
In the narrow sense, it refers specifically to the technology or system that verifies or looks up an identity based on a face.
Face recognition is an active research area in computer vision. It belongs to biometric recognition, which distinguishes individuals (usually people) by their biological characteristics.
The characteristics studied in biometrics include the face, fingerprints, palm prints, iris, retina, voice, body shape, and personal habits (such as keystroke rhythm and pressure, or a signature).
The corresponding recognition technologies are face recognition, fingerprint recognition, palm print recognition, iris recognition, retina recognition, speech recognition (speech can be used both to identify the speaker and to recognize the spoken content; only the former is a biometric technology), body shape recognition, keystroke recognition, and signature recognition.

#### Industry status
Face recognition is already widely used across industries, for example in face-based access control and face payment, and its applications keep expanding as the technology improves. China's face recognition technology is among the world leaders; in the security industry, the major domestic vendors have all released their own face recognition products and solutions, and the broader security sector remains the main application area.

#### Technology trends
Face recognition today is mostly built on deep neural network models. Compared with traditional techniques, features extracted by deep learning are richer and better capture the similarity between faces, which significantly improves accuracy. Big data and compute have both advanced rapidly in recent years, and deep learning depends heavily on both, which is why the technology has made its breakthrough during this period. More and richer training data makes the models more general and closer to the real world; more compute allows deeper architectures. The theory behind deep learning also keeps improving, and better models will keep raising the level of face recognition technology.

#### Key technologies in face recognition
The key technologies involved are: face detection, face landmarks, face feature extraction, face comparison, and face alignment.
![face_sdk](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/face_sdk/images/face_sdk.png)

The examples in this project provide reference implementations of face feature extraction and face comparison.
#### Face feature extraction:
Model inference example: FeatureExtractionExample

#### Face feature comparison:
Face comparison example: FeatureComparisonExample


#### Running the face feature extraction example - FeatureExtractionExample
After it runs successfully, you should see output like the following on the command line:
```text
[INFO ] - Face feature: [-0.04026184, -0.019486362, -0.09802659, 0.01700999, 0.037829027, ...]
```

#### Running the face feature comparison example - FeatureComparisonExample
`src/test/resources/kana1.jpg`
![kana1](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/face_sdk/images/kana1.jpg)
`src/test/resources/kana2.jpg`
![kana2](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/face_sdk/images/kana2.jpg)

After it runs successfully, you should see output like the following on the command line:
The comparison score is the cosine similarity of the two feature vectors, mapped into the range [0, 1] (see FaceFeature.calculSimilar).

```text
[INFO ] - face1 feature: [-0.040261842, -0.019486364, ..., 0.031147916, -0.032064643]
[INFO ] - face2 feature: [-0.049654193, -0.04029847, ..., 0.04562381, -0.044428844]
[INFO ] - 相似度: 0.9022608
```
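For reference, the score above is produced by `FaceFeature.calculSimilar` (included later in this commit): it is the cosine similarity of the two embeddings, shifted into [0, 1]. Below is a minimal sketch of a 1:1 verification decision built on that score; the four-element vectors are truncated placeholders taken from the log lines above, and the 0.8 acceptance threshold is an illustrative assumption, not a value from this project.

```java
// Sketch only: cosine similarity mapped to [0, 1], as in FaceFeature.calculSimilar,
// plus a hypothetical acceptance threshold (0.8 is an assumption; tune per use case).
public final class SimilaritySketch {

    static float similarity(float[] f1, float[] f2) {
        float dot = 0f, mod1 = 0f, mod2 = 0f;
        for (int i = 0; i < f1.length; i++) {
            dot += f1[i] * f2[i];
            mod1 += f1[i] * f1[i];
            mod2 += f2[i] * f2[i];
        }
        // cosine in [-1, 1] -> shift to [0, 1]
        return (float) ((dot / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
    }

    public static void main(String[] args) {
        // Truncated placeholder embeddings; real features come from predictor.predict(img).
        float[] face1 = {-0.040261842f, -0.019486364f, 0.031147916f, -0.032064643f};
        float[] face2 = {-0.049654193f, -0.04029847f, 0.04562381f, -0.044428844f};
        float sim = similarity(face1, face2);
        System.out.println("similarity: " + sim + " -> " + (sim >= 0.8f ? "same person" : "different person"));
    }
}
```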

### Open-source algorithms
#### 1. Open-source algorithms used by the SDK
- [facenet-pytorch](https://github.com/timesler/facenet-pytorch)

#### 2. How are the models exported?
- [how_to_convert_your_model_to_torchscript](http://docs.djl.ai/docs/pytorch/how_to_convert_your_model_to_torchscript.html)


### Other help
http://aias.top/guides.html


### Git repositories:
[GitHub](https://github.com/mymagicpower/AIAS)
[Gitee](https://gitee.com/mymagicpower/AIAS)
@@ -1,42 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
  <component name="CheckStyle-IDEA-Module">
    <option name="configuration">
      <map />
    </option>
  </component>
  <component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
    <output url="file://$MODULE_DIR$/target/classes" />
    <output-test url="file://$MODULE_DIR$/target/test-classes" />
    <content url="file://$MODULE_DIR$">
      <sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
      <sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
      <sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
      <sourceFolder url="file://$MODULE_DIR$/src/test/resources" type="java-test-resource" />
      <excludeFolder url="file://$MODULE_DIR$/target" />
    </content>
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
    <orderEntry type="library" name="aais-face-feature-lib-0.1.0" level="project" />
    <orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
    <orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-slf4j-impl:2.17.2" level="project" />
    <orderEntry type="library" name="Maven: org.slf4j:slf4j-api:1.7.25" level="project" />
    <orderEntry type="library" name="Maven: org.apache.logging.log4j:log4j-api:2.17.2" level="project" />
    <orderEntry type="library" scope="RUNTIME" name="Maven: org.apache.logging.log4j:log4j-core:2.17.2" level="project" />
    <orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.5" level="project" />
    <orderEntry type="library" name="Maven: ai.djl:api:0.17.0" level="project" />
    <orderEntry type="library" name="Maven: net.java.dev.jna:jna:5.10.0" level="project" />
    <orderEntry type="library" name="Maven: org.apache.commons:commons-compress:1.21" level="project" />
    <orderEntry type="library" name="Maven: ai.djl:basicdataset:0.17.0" level="project" />
    <orderEntry type="library" name="Maven: org.apache.commons:commons-csv:1.9.0" level="project" />
    <orderEntry type="library" name="Maven: ai.djl:model-zoo:0.17.0" level="project" />
    <orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-engine:0.17.0" level="project" />
    <orderEntry type="library" name="Maven: ai.djl.pytorch:pytorch-model-zoo:0.17.0" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: org.testng:testng:6.8.1" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.10" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.1" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: org.beanshell:bsh:2.0b4" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: com.beust:jcommander:1.27" level="project" />
    <orderEntry type="library" scope="TEST" name="Maven: org.yaml:snakeyaml:1.6" level="project" />
  </component>
</module>
@@ -1,60 +0,0 @@

#### Common Model Loading Methods


1. How to load a model online via URL?
```text
# Use optModelUrls to load a model via URL

Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelUrls("https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_ppocr_mobile_v2.0_det_infer.zip")
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();
```

2. How to load a model locally?
```text
# Use optModelPath to load a model from a zipped file
Path modelPath = Paths.get("src/test/resources/ch_ppocr_mobile_v2.0_det_infer.zip");
Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelPath(modelPath)
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();

# Use optModelPath to load a model from a local directory
Path modelPath = Paths.get("src/test/resources/ch_ppocr_mobile_v2.0_det_infer/");
Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelPath(modelPath)
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();
```

3. How to load a model packed into a JAR file?
```text
# Use optModelUrls to load a model
# Assuming the model is located in the JAR file at:
# BOOT-INF/classes/ch_ppocr_mobile_v2.0_det_infer.zip

Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelUrls("jar:///ch_ppocr_mobile_v2.0_det_infer.zip")
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();
```

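Once a `Criteria` has been built with any of the methods above, loading the model and running inference follows the same pattern as the examples in this repository. A minimal sketch is shown below; the image path is a placeholder, and the snippet continues from one of the `criteria` definitions above.

```java
// Sketch: load the model described by `criteria` and run one prediction.
// Assumes `criteria` is a Criteria<Image, DetectedObjects> built as shown above.
Image img = ImageFactory.getInstance()
        .fromFile(Paths.get("src/test/resources/some_image.png")); // placeholder path

try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
     Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
    DetectedObjects detections = predictor.predict(img);
    System.out.println(detections);
}
```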
@@ -1,101 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ Licensed to the Apache Software Foundation (ASF) under one
  ~ or more contributor license agreements. See the NOTICE file
  ~ distributed with this work for additional information
  ~ regarding copyright ownership. The ASF licenses this file
  ~ to you under the Apache License, Version 2.0 (the
  ~ "License"); you may not use this file except in compliance
  ~ with the License. You may obtain a copy of the License at
  ~
  ~   http://www.apache.org/licenses/LICENSE-2.0
  ~
  ~ Unless required by applicable law or agreed to in writing,
  ~ software distributed under the License is distributed on an
  ~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  ~ KIND, either express or implied. See the License for the
  ~ specific language governing permissions and limitations
  ~ under the License.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>aias</groupId>
    <artifactId>face-feature-sdk</artifactId>
    <version>0.17.0</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <djl.version>0.17.0</djl.version>
    </properties>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
                <version>3.8.1</version>
            </plugin>
        </plugins>
    </build>
    <dependencies>
        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.17.2</version>
        </dependency>
        <dependency>
            <groupId>com.google.code.gson</groupId>
            <artifactId>gson</artifactId>
            <version>2.8.5</version>
        </dependency>
        <!-- Server-side inference engine -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!-- PyTorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>6.8.1</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
@@ -1,52 +0,0 @@
package me.aias;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import me.aias.util.FaceFeature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Face comparison - 1:1.
 *
 * @author Calvin
 * @email 179209347@qq.com
 * @website www.aias.top
 */
public final class FeatureComparisonExample {

  private static final Logger logger = LoggerFactory.getLogger(FeatureComparisonExample.class);

  private FeatureComparisonExample() {
  }

  public static void main(String[] args) throws IOException, ModelException, TranslateException {

    Path imageFile1 = Paths.get("src/test/resources/kana1.jpg");
    Image img1 = ImageFactory.getInstance().fromFile(imageFile1);
    Path imageFile2 = Paths.get("src/test/resources/kana2.jpg");
    Image img2 = ImageFactory.getInstance().fromFile(imageFile2);

    FaceFeature faceFeature = new FaceFeature();
    try (ZooModel<Image, float[]> model = ModelZoo.loadModel(faceFeature.criteria());
        Predictor<Image, float[]> predictor = model.newPredictor()) {

      float[] feature1 = predictor.predict(img1);
      logger.info("face1 feature: " + Arrays.toString(feature1));
      float[] feature2 = predictor.predict(img2);
      logger.info("face2 feature: " + Arrays.toString(feature2));

      // "相似度" = similarity, computed by FaceFeature.calculSimilar
      logger.info("相似度: " + Float.toString(faceFeature.calculSimilar(feature1, feature2)));
    }
  }
}
@@ -1,46 +0,0 @@
package me.aias;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import me.aias.util.FaceFeature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Face feature extraction.
 *
 * @author Calvin
 * @email 179209347@qq.com
 * @website www.aias.top
 */
public final class FeatureExtractionExample {

  private static final Logger logger = LoggerFactory.getLogger(FeatureExtractionExample.class);

  private FeatureExtractionExample() {
  }

  public static void main(String[] args) throws IOException, ModelException, TranslateException {
    Path imageFile = Paths.get("src/test/resources/kana1.jpg");
    Image img = ImageFactory.getInstance().fromFile(imageFile);

    FaceFeature faceFeature = new FaceFeature();
    try (ZooModel<Image, float[]> model = ModelZoo.loadModel(faceFeature.criteria());
        Predictor<Image, float[]> predictor = model.newPredictor()) {
      float[] feature = predictor.predict(img);
      if (feature != null) {
        logger.info("Face feature: " + Arrays.toString(feature));
      }
    }
  }
}
@@ -1,39 +0,0 @@
package me.aias.util;

import ai.djl.Device;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;

import java.nio.file.Paths;

public final class FaceFeature {

  public FaceFeature() {}

  public Criteria<Image, float[]> criteria() {
    Criteria<Image, float[]> criteria =
        Criteria.builder()
            .setTypes(Image.class, float[].class)
            .optModelPath(Paths.get("models/face_feature.zip"))
            .optModelName("face_feature")
            .optTranslator(new FaceFeatureTranslator())
            .optProgress(new ProgressBar())
            .optEngine("PyTorch") // Use PyTorch engine
            .optDevice(Device.cpu())
            .build();

    return criteria;
  }

  // Cosine similarity of the two feature vectors, shifted from [-1, 1] into [0, 1].
  public float calculSimilar(float[] feature1, float[] feature2) {
    float ret = 0.0f;
    float mod1 = 0.0f;
    float mod2 = 0.0f;
    int length = feature1.length;
    for (int i = 0; i < length; ++i) {
      ret += feature1[i] * feature2[i];
      mod1 += feature1[i] * feature1[i];
      mod2 += feature2[i] * feature2[i];
    }
    return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
  }
}
@@ -1,54 +0,0 @@
package me.aias.util;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

public final class FaceFeatureTranslator implements Translator<Image, float[]> {

  public FaceFeatureTranslator() {}

  /** {@inheritDoc} */
  @Override
  public NDList processInput(TranslatorContext ctx, Image input) {
    NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
    Pipeline pipeline = new Pipeline();
    pipeline
        // .add(new Resize(160))
        .add(new ToTensor())
        .add(
            new Normalize(
                new float[] {127.5f / 255.0f, 127.5f / 255.0f, 127.5f / 255.0f},
                new float[] {128.0f / 255.0f, 128.0f / 255.0f, 128.0f / 255.0f}));

    return pipeline.transform(new NDList(array));
  }

  /** {@inheritDoc} */
  @Override
  public float[] processOutput(TranslatorContext ctx, NDList list) {
    NDList result = new NDList();
    long numOutputs = list.singletonOrThrow().getShape().get(0);
    for (int i = 0; i < numOutputs; i++) {
      result.add(list.singletonOrThrow().get(i));
    }
    float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
    float[] feature = new float[embeddings.length];
    for (int i = 0; i < embeddings.length; i++) {
      feature[i] = embeddings[i][0];
    }
    return feature;
  }

  /** {@inheritDoc} */
  @Override
  public Batchifier getBatchifier() {
    return Batchifier.STACK;
  }
}
@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="INFO">
    <Appenders>
        <Console name="console" target="SYSTEM_OUT">
            <PatternLayout
                    pattern="[%-5level] - %msg%n"/>
        </Console>
    </Appenders>
    <Loggers>
        <Root level="info" additivity="false">
            <AppenderRef ref="console"/>
        </Root>
        <Logger name="me.calvin" level="${sys:me.calvin.logging.level:-info}" additivity="false">
            <AppenderRef ref="console"/>
        </Logger>
    </Loggers>
</Configuration>
@@ -1,38 +1,62 @@
### Official website:
[Official website](https://www.aias.top/)

### Download the models, place them in the models directory, and unzip
- Link: https://github.com/mymagicpower/AIAS/releases/download/apps/text_recognition_models.zip
- Link: https://pan.baidu.com/s/1AGKdyvVeRONOhAHu-Ot7RA?pwd=3m2f

## Text recognition (OCR) toolbox
Text recognition (OCR) is currently widely used in multiple industries, such as document recognition input in the financial industry, invoice recognition in the catering industry, ticket recognition in the transportation field, various form recognition in enterprises, and identification card, driver's license, and passport recognition commonly used in daily work and life. OCR (text recognition) is a commonly used AI capability.

### OCR toolbox functions:

### 1. Direction detection
- OcrDirectionExample
- 0 degrees
- 90 degrees
- 180 degrees
- 270 degrees

![detect_direction](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/detect_direction.png)

### 2. Image rotation
- RotationExample

Each call to the rotateImg method rotates the image counterclockwise by 90 degrees.
- Example code: RotationExample.java
- Image before rotation:
![ticket_0](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/ticket_0.png)
- The resulting image after rotation is as follows:
![rotate_result](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/rotate_result.png)

### 3. Text recognition (1 & 2 can be used as auxiliary steps when needed)
- OcrV3RecognitionExample

### Running the OCR recognition example
#### 1.1 Text recognition:
- Example code: OcrV3RecognitionExample.java
- After it runs successfully, you should see output like the following on the command line:
```text
检票:B1
Z31C014941
九江站
南昌站
D6262
Nanchang
Jiujiang
03车02A号
2019年06月07日06:56开
二等座
网折
¥39.5元
折
限乘当日当次车
3604211990****2417
买票请到12306发货请到95306
中国铁路祝您旅途愉快
32270300310607C014941上海南售
time: 790
```

- The resulting output image looks like this:
![texts](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/texts_result.png)


### Git repositories:
[GitHub](https://github.com/mymagicpower/AIAS)
[Gitee](https://gitee.com/mymagicpower/AIAS)


#### Help documentation:
- https://aias.top/guides.html
- 1. Performance optimization FAQ:
  - https://aias.top/AIAS/guides/performance.html
- 2. Engine configuration (CPU and GPU, automatic online loading and local setup):
  - https://aias.top/AIAS/guides/engine_config.html
- 3. Model loading methods (automatic online loading and local setup):
  - https://aias.top/AIAS/guides/load_model.html
- 4. Windows environment FAQ:
  - https://aias.top/AIAS/guides/windows.html
@@ -1,82 +0,0 @@
### Official website:
[Official website](https://www.aias.top/)

### Download the models and place them in the models directory
- Link: https://pan.baidu.com/s/1AGKdyvVeRONOhAHu-Ot7RA?pwd=3m2f

## Text recognition (OCR) toolbox
Text recognition (OCR) is widely used in many industries, for example document entry in finance, invoice recognition in catering, ticket recognition in transportation, form recognition in enterprises, and the ID card, driver's license, and passport recognition we encounter in daily work and life.
OCR is one of the most commonly used AI capabilities today.

### OCR toolbox functions:

#### 1. Direction detection
- OcrDirectionExample
- 0 degrees
- 90 degrees
- 180 degrees
- 270 degrees
![detect_direction](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/detect_direction.png)

#### 2. Image rotation
- RotationExample

#### 3. Text recognition (1 & 2 can be used as auxiliary steps when needed)
- OcrV3RecognitionExample


### Running the OCR recognition example
#### 1.1 Text recognition:
- Example code: OcrV3RecognitionExample.java
- After it runs successfully, you should see output like the following on the command line:
```text
检票:B1
Z31C014941
九江站
南昌站
D6262
Nanchang
Jiujiang
03车02A号
2019年06月07日06:56开
二等座
网折
¥39.5元
折
限乘当日当次车
3604211990****2417
买票请到12306发货请到95306
中国铁路祝您旅途愉快
32270300310607C014941上海南售
time: 790
```

- The resulting output image looks like this:
![texts](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/texts_result.png)


#### 2. Image rotation:
Each call to the rotateImg method rotates the image counterclockwise by 90 degrees (see the sketch after the images below).
- Example code: RotationExample.java
- Image before rotation:
![ticket_0](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/ticket_0.png)
- Image after rotation:
![rotate_result](https://aias-home.oss-cn-beijing.aliyuncs.com/AIAS/OCR/images/rotate_result.png)
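The rotation itself is a one-liner on the NDArray representation of the image. Below is a minimal sketch of what RotationExample (included later in this commit) does for a single 90-degree counterclockwise turn; `Image`, `ImageFactory`, `NDArray`, `NDManager`, and `NDImageUtils` are the DJL classes it already imports.

```java
// Sketch: rotate an Image 90 degrees counterclockwise using DJL's NDImageUtils,
// as done in RotationExample. Call it repeatedly for 180/270 degrees.
private static Image rotateImg(Image image) {
    try (NDManager manager = NDManager.newBaseManager()) {
        NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
        return ImageFactory.getInstance().fromNDArray(rotated);
    }
}
```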


### Git repositories:
[GitHub](https://github.com/mymagicpower/AIAS)
[Gitee](https://gitee.com/mymagicpower/AIAS)


#### Help documentation:
- https://aias.top/guides.html
- 1. Performance optimization FAQ:
  - https://aias.top/AIAS/guides/performance.html
- 2. Engine configuration (CPU and GPU, automatic online loading and local setup):
  - https://aias.top/AIAS/guides/engine_config.html
- 3. Model loading methods (automatic online loading and local setup):
  - https://aias.top/AIAS/guides/load_model.html
- 4. Windows environment FAQ:
  - https://aias.top/AIAS/guides/windows.html
@@ -101,7 +101,6 @@
            <version>${djl.version}</version>
        </dependency>

        <!-- ONNX has no NDArray implementation of its own, so a PyTorch or MXNet engine must also be included; in practice PyTorch has the better operator support -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
@@ -115,11 +114,6 @@
            <version>${djl.version}</version>
        </dependency>

        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
            <version>1.5.7</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.opencv</groupId>
            <artifactId>opencv</artifactId>

@@ -1,58 +0,0 @@
package me.aias.example;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import me.aias.example.cls.OcrDirectionDetection;
import me.aias.example.common.DirectionInfo;
import me.aias.example.common.ImageUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/**
 * OCR文字方向检测(轻量级模型).
 *
 * OCR text direction detection (light model)
 *
 * @author Calvin
 * @date 2021-10-04
 * @email 179209347@qq.com
 */
public final class OcrDirectionExample {

  private static final Logger logger = LoggerFactory.getLogger(OcrDirectionExample.class);

  private OcrDirectionExample() {}

  public static void main(String[] args) throws IOException, ModelException, TranslateException {
    Path imageFile = Paths.get("src/test/resources/ticket_90.png");
    Image image = ImageFactory.getInstance().fromFile(imageFile);

    OcrDirectionDetection detection = new OcrDirectionDetection();
    try (ZooModel detectionModel = ModelZoo.loadModel(detection.detectCriteria());
        Predictor<Image, DetectedObjects> detector = detectionModel.newPredictor();
        ZooModel rotateModel = ModelZoo.loadModel(detection.clsCriteria());
        Predictor<Image, DirectionInfo> rotateClassifier = rotateModel.newPredictor()) {

      DetectedObjects detections = detection.predict(image, detector, rotateClassifier);

      List<DetectedObjects.DetectedObject> boxes = detections.items();
      for (DetectedObjects.DetectedObject result : boxes) {
        System.out.println(result.getClassName() + " : " + result.getProbability());
      }

      ImageUtils.saveBoundingBoxImage(image, detections, "cls_detect_result.png", "build/output");
      logger.info("{}", detections);
    }
  }
}
@@ -3,7 +3,6 @@ package me.aias.example;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.opencv.OpenCVImageFactory;
@@ -12,18 +11,15 @@ import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import me.aias.example.common.ImageUtils;
import me.aias.example.detection.OcrV3Detection;
import me.aias.example.opencv.OpenCVUtils;
import me.aias.example.recognition.OcrV3Recognition;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/**
@@ -62,7 +58,7 @@ public final class OcrV3RecognitionExample {

        // Convert to BufferedImage to avoid garbled Chinese text
        Mat wrappedImage = (Mat) image.getWrappedImage();
        BufferedImage bufferedImage = OpenCVUtils.mat2Image(wrappedImage);
        BufferedImage bufferedImage = ImageUtils.mat2Image(wrappedImage);
        for (DetectedObjects.DetectedObject item : items) {
            Rectangle rectangle = item.getBoundingBox().getBounds();
            int x = (int) (rectangle.getX() * image.getWidth());
@@ -74,7 +70,7 @@ public final class OcrV3RecognitionExample {
            ImageUtils.drawImageText(bufferedImage, item.getClassName(), x, y);
        }

        Mat image2Mat = OpenCVUtils.image2Mat(bufferedImage);
        Mat image2Mat = ImageUtils.image2Mat(bufferedImage);
        image = OpenCVImageFactory.getInstance().fromImage(image2Mat);
        ImageUtils.saveImage(image, "ocr_result.png", "build/output");

@@ -1,59 +0,0 @@
package me.aias.example;

import ai.djl.ModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * 图片旋转
 * Rotation Example
 *
 * @author Calvin
 * @date 2021-06-28
 * @email 179209347@qq.com
 */
public final class RotationExample {

  private static final Logger logger = LoggerFactory.getLogger(RotationExample.class);

  private RotationExample() {
  }

  public static void main(String[] args) throws IOException {
    Path imageFile = Paths.get("src/test/resources/ticket_0.png");
    Image image = ImageFactory.getInstance().fromFile(imageFile);
    // 逆时针旋转
    // Counterclockwise rotation
    image = rotateImg(image);

    saveImage(image, "rotate_result.png", "build/output");
  }

  private static Image rotateImg(Image image) {
    try (NDManager manager = NDManager.newBaseManager()) {
      NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
      return ImageFactory.getInstance().fromNDArray(rotated);
    }
  }

  public static void saveImage(Image img, String name, String path) {
    Path outputDir = Paths.get(path);
    Path imagePath = outputDir.resolve(name);
    try {
      img.save(Files.newOutputStream(imagePath), "png");
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}
@@ -1,138 +0,0 @@
package me.aias.example.cls;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import me.aias.example.common.DirectionInfo;
import me.aias.example.detection.PpWordDetectionTranslator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

public final class OcrDirectionDetection {

  private static final Logger logger = LoggerFactory.getLogger(OcrDirectionDetection.class);

  public OcrDirectionDetection() {
  }

  public DetectedObjects predict(
      Image image,
      Predictor<Image, DetectedObjects> detector,
      Predictor<Image, DirectionInfo> rotateClassifier)
      throws TranslateException {
    DetectedObjects detections = detector.predict(image);

    List<DetectedObjects.DetectedObject> boxes = detections.items();

    List<String> names = new ArrayList<>();
    List<Double> prob = new ArrayList<>();
    List<BoundingBox> rect = new ArrayList<>();

    for (int i = 0; i < boxes.size(); i++) {
      Image subImg = getSubImage(image, boxes.get(i).getBoundingBox());
      DirectionInfo result = null;
      if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
        subImg = rotateImg(subImg);
        result = rotateClassifier.predict(subImg);
        prob.add(result.getProb());
        if (result.getName().equalsIgnoreCase("Rotate")) {
          names.add("90");
        } else {
          names.add("270");
        }
      } else {
        result = rotateClassifier.predict(subImg);
        prob.add(result.getProb());
        if (result.getName().equalsIgnoreCase("No Rotate")) {
          names.add("0");
        } else {
          names.add("180");
        }
      }
      rect.add(boxes.get(i).getBoundingBox());
    }
    DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect);

    return detectedObjects;
  }

  public Criteria<Image, DetectedObjects> detectCriteria() {
    Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
            .optEngine("OnnxRuntime")
            .optModelName("inference")
            .setTypes(Image.class, DetectedObjects.class)
            .optModelPath(Paths.get("models/ch_PP-OCRv2_det_infer_onnx.zip"))
            .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
            .optProgress(new ProgressBar())
            .build();

    return criteria;
  }

  public Criteria<Image, DirectionInfo> clsCriteria() {

    Criteria<Image, DirectionInfo> criteria =
        Criteria.builder()
            .optEngine("OnnxRuntime")
            .optModelName("inference")
            .setTypes(Image.class, DirectionInfo.class)
            .optModelPath(Paths.get("models/ch_ppocr_mobile_v2.0_cls_onnx.zip"))
            .optTranslator(new PpWordRotateTranslator())
            .optProgress(new ProgressBar())
            .build();
    return criteria;
  }

  private Image getSubImage(Image img, BoundingBox box) {
    Rectangle rect = box.getBounds();
    double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
    int width = img.getWidth();
    int height = img.getHeight();
    int[] recovered = {
      (int) (extended[0] * width),
      (int) (extended[1] * height),
      (int) (extended[2] * width),
      (int) (extended[3] * height)
    };
    return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
  }

  private double[] extendRect(double xmin, double ymin, double width, double height) {
    double centerx = xmin + width / 2;
    double centery = ymin + height / 2;
    if (width > height) {
      width += height * 2.0;
      height *= 3.0;
    } else {
      height += width * 2.0;
      width *= 3.0;
    }
    double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;
    double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;
    double newWidth = newX + width > 1 ? 1 - newX : width;
    double newHeight = newY + height > 1 ? 1 - newY : height;
    return new double[]{newX, newY, newWidth, newHeight};
  }

  private Image rotateImg(Image image) {
    try (NDManager manager = NDManager.newBaseManager()) {
      NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
      return ImageFactory.getInstance().fromNDArray(rotated);
    }
  }
}
@@ -1,76 +0,0 @@
package me.aias.example.cls;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import me.aias.example.common.DirectionInfo;

import java.util.Arrays;
import java.util.List;

public class PpWordRotateTranslator implements Translator<Image, DirectionInfo> {
  List<String> classes = Arrays.asList("No Rotate", "Rotate");

  public PpWordRotateTranslator() {
  }

  public DirectionInfo processOutput(TranslatorContext ctx, NDList list) {
    NDArray prob = list.singletonOrThrow();
    float[] res = prob.toFloatArray();
    int maxIndex = 0;
    if (res[1] > res[0]) {
      maxIndex = 1;
    }

    return new DirectionInfo(classes.get(maxIndex), Double.valueOf(res[maxIndex]));
  }

  // public NDList processInput2(TranslatorContext ctx, Image input) {
  //   NDArray img = input.toNDArray(ctx.getNDManager());
  //   img = NDImageUtils.resize(img, 192, 48);
  //   img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
  //   img = img.expandDims(0);
  //   return new NDList(new NDArray[]{img});
  // }

  public NDList processInput(TranslatorContext ctx, Image input) {
    NDArray img = input.toNDArray(ctx.getNDManager());
    int imgC = 3;
    int imgH = 48;
    int imgW = 192;

    NDArray array = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW));

    int h = input.getHeight();
    int w = input.getWidth();
    int resized_w = 0;

    float ratio = (float) w / (float) h;
    if (Math.ceil(imgH * ratio) > imgW) {
      resized_w = imgW;
    } else {
      resized_w = (int) (Math.ceil(imgH * ratio));
    }

    img = NDImageUtils.resize(img, resized_w, imgH);

    img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
    // img = img.transpose(2, 0, 1);

    array.set(new NDIndex(":,:,0:" + resized_w), img);

    array = array.expandDims(0);

    return new NDList(new NDArray[]{array});
  }

  public Batchifier getBatchifier() {
    return null;
  }
}
@@ -1,27 +0,0 @@
package me.aias.example.common;

public class DirectionInfo {
  private String name;
  private Double prob;

  public DirectionInfo(String name, Double prob) {
    this.name = name;
    this.prob = prob;
  }

  public String getName() {
    return name;
  }

  public void setName(String name) {
    this.name = name;
  }

  public Double getProb() {
    return prob;
  }

  public void setProb(Double prob) {
    this.prob = prob;
  }
}
@@ -2,9 +2,13 @@ package me.aias.example.common;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.imgproc.Imgproc;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -80,4 +84,35 @@
            graphics.dispose();
        }
    }

    /**
     * Mat to BufferedImage
     *
     * @param mat the OpenCV Mat (BGR) to convert
     * @return the converted BufferedImage
     */
    public static BufferedImage mat2Image(Mat mat) {
        int width = mat.width();
        int height = mat.height();
        byte[] data = new byte[width * height * (int) mat.elemSize()];
        // 4 == Imgproc.COLOR_BGR2RGB
        Imgproc.cvtColor(mat, mat, 4);
        mat.get(0, 0, data);
        // 5 == BufferedImage.TYPE_3BYTE_BGR
        BufferedImage ret = new BufferedImage(width, height, 5);
        ret.getRaster().setDataElements(0, 0, width, height, data);
        return ret;
    }

    /**
     * BufferedImage to Mat
     *
     * @param img the BufferedImage to convert
     * @return the converted OpenCV Mat (8UC3)
     */
    public static Mat image2Mat(BufferedImage img) {
        int width = img.getWidth();
        int height = img.getHeight();
        byte[] data = ((DataBufferByte) img.getRaster().getDataBuffer()).getData();
        Mat mat = new Mat(height, width, CvType.CV_8UC3);
        mat.put(0, 0, data);
        return mat;
    }
}
@@ -1,43 +0,0 @@
package me.aias.example.opencv;

import org.opencv.core.Mat;
import org.opencv.core.CvType;
import org.opencv.imgproc.Imgproc;

import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;


public class OpenCVUtils {
  /**
   * Mat to BufferedImage
   *
   * @param mat
   * @return
   */
  public static BufferedImage mat2Image(Mat mat) {
    int width = mat.width();
    int height = mat.height();
    byte[] data = new byte[width * height * (int) mat.elemSize()];
    Imgproc.cvtColor(mat, mat, 4);
    mat.get(0, 0, data);
    BufferedImage ret = new BufferedImage(width, height, 5);
    ret.getRaster().setDataElements(0, 0, width, height, data);
    return ret;
  }

  /**
   * BufferedImage to Mat
   *
   * @param img
   * @return
   */
  public static Mat image2Mat(BufferedImage img) {
    int width = img.getWidth();
    int height = img.getHeight();
    byte[] data = ((DataBufferByte) img.getRaster().getDataBuffer()).getData();
    Mat mat = new Mat(height, width, CvType.CV_8UC3);
    mat.put(0, 0, data);
    return mat;
  }
}