diff --git a/archive/1_image_sdks/ocr_iocr_sdk/pom.xml b/archive/1_image_sdks/ocr_iocr_sdk/pom.xml new file mode 100755 index 00000000..d360b2e3 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/pom.xml @@ -0,0 +1,155 @@ + + + + + 4.0.0 + + aias + ocr_preprocess_sdk + 0.22.1 + + + UTF-8 + 1.8 + 1.8 + 0.22.1 + + + + + + org.springframework.boot + spring-boot-maven-plugin + 2.2.6.RELEASE + + me.aias.ocr.OcrV3RecognitionExample + ZIP + + + + + repackage + + + + + + + + + + + + + + + + + + + + + + + commons-cli + commons-cli + 1.4 + + + org.apache.logging.log4j + log4j-slf4j-impl + 2.17.2 + + + com.google.code.gson + gson + 2.8.5 + + + + ai.djl + api + ${djl.version} + + + ai.djl + basicdataset + ${djl.version} + + + ai.djl + model-zoo + ${djl.version} + + + + + + + + + + ai.djl.pytorch + pytorch-engine + ${djl.version} + + + + + ai.djl.onnxruntime + onnxruntime-engine + ${djl.version} + + + + + org.bytedeco + javacv-platform + 1.5.7 + + + + ai.djl.opencv + opencv + ${djl.version} + + + + org.apache.commons + commons-lang3 + 3.12.0 + + + commons-collections + commons-collections + 3.2.2 + + + org.projectlombok + lombok + 1.18.18 + provided + + + + + \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPool.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPool.java new file mode 100644 index 00000000..59166b85 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPool.java @@ -0,0 +1,50 @@ +package me.aias.example;// 导入需要的包 + +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDList; +import ai.djl.repository.zoo.ZooModel; + +import java.util.ArrayList; + +public class DetectorPool { + private int poolSize; + private ZooModel detectionModel; + private ArrayList> detectorList = new ArrayList<>(); + + + public DetectorPool(int poolSize, ZooModel detectionModel) { + this.poolSize = poolSize; + 
this.detectionModel = detectionModel; + + for (int i = 0; i < poolSize; i++) { + Predictor detector = detectionModel.newPredictor(); + detectorList.add(detector); + } + } + + public synchronized Predictor getDetector() { + while (detectorList.isEmpty()) { + try { + wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + Predictor detector = detectorList.remove(0); + return detector; + } + + public synchronized void releaseDetector(Predictor detector) { + detectorList.add(detector); + notifyAll(); + } + + public void close() { + detectionModel.close(); + for (Predictor detector : detectorList) { + detector.close(); + } + + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPoolExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPoolExample.java new file mode 100644 index 00000000..571327be --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/DetectorPoolExample.java @@ -0,0 +1,71 @@ +package me.aias.example; + + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDList; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import me.aias.example.utils.detection.OCRDetectionTranslator; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.concurrent.*; + +public class DetectorPoolExample { + /** + * 文本检测 + * + * @return + */ + private static Criteria detectCriteria(String detUri) { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + 
.optModelPath(Paths.get(detUri)) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException { + Path imageFile = Paths.get("src/test/resources/template.png"); + Image templateImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + + ZooModel detectionModel = ModelZoo.loadModel(detectCriteria("models/ch_PP-OCRv3_det_infer_onnx.zip")); + + int nThreads = 5; // 并发数量 + DetectorPool detectorPool = new DetectorPool(3, detectionModel); + ExecutorService executorService = Executors.newFixedThreadPool(nThreads); // 3是线程池的大小 + + for (int i = 0; i < 10; i++) { + final int index = i; + executorService.execute(new Runnable() { + public void run() { + // 这里是需要异步执行的代码 + try { + Predictor detector = detectorPool.getDetector(); + NDList list = detector.predict(templateImg); + detectorPool.releaseDetector(detector); + System.out.println("" + index + ": "+ list.size()); + } catch (TranslateException e) { + e.printStackTrace(); + } + } + }); + } + executorService.shutdown(); // 当所有任务执行完毕后关闭线程池 + + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/HorizontalDetectorPool.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/HorizontalDetectorPool.java new file mode 100644 index 00000000..c8dde4be --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/HorizontalDetectorPool.java @@ -0,0 +1,51 @@ +package me.aias.example;// 导入需要的包 + +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.ndarray.NDList; +import ai.djl.repository.zoo.ZooModel; + +import java.util.ArrayList; + +public class HorizontalDetectorPool { + private int poolSize; + private ZooModel detectionModel; + private ArrayList> detectorList = new 
ArrayList<>(); + + + public HorizontalDetectorPool(int poolSize, ZooModel detectionModel) { + this.poolSize = poolSize; + this.detectionModel = detectionModel; + + for (int i = 0; i < poolSize; i++) { + Predictor detector = detectionModel.newPredictor(); + detectorList.add(detector); + } + } + + public synchronized Predictor getDetector(){ + while (detectorList.isEmpty()) { + try { + wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + Predictor detector = detectorList.remove(0); + return detector; + } + + public synchronized void releaseDetector(Predictor detector) { + detectorList.add(detector); + notifyAll(); + } + + public void close() { + detectionModel.close(); + for (Predictor detector : detectorList) { + detector.close(); + } + + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdExample.java new file mode 100755 index 00000000..c185f363 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdExample.java @@ -0,0 +1,47 @@ +package me.aias.example; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.translate.TranslateException; +import me.aias.example.model.MlsdSquareModel; +import me.aias.example.model.SingleRecognitionModel; +import me.aias.example.utils.common.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public final class MlsdExample { + + private static final Logger logger = LoggerFactory.getLogger(MlsdExample.class); + + private 
MlsdExample() { + } + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/warp1.png"); + Image image = OpenCVImageFactory.getInstance().fromFile(imageFile); + + try (MlsdSquareModel mlsdSquareModel = new MlsdSquareModel(); + NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) { + mlsdSquareModel.init("models/mlsd_traced_model_onnx.zip"); + + Image newImg = mlsdSquareModel.predict(image); + if(newImg != null) + ImageUtils.saveImage(newImg, "newImg.png", "build/output"); + else + System.out.println("failure"); + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdSquareCompExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdSquareCompExample.java new file mode 100755 index 00000000..07a13148 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/MlsdSquareCompExample.java @@ -0,0 +1,114 @@ +package me.aias.example; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.translate.TranslateException; +import me.aias.example.model.MlsdSquareModel; +import me.aias.example.model.SingleRecognitionModel; +import me.aias.example.utils.common.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public final class MlsdSquareCompExample { + + private static final Logger logger = LoggerFactory.getLogger(MlsdSquareCompExample.class); + + private MlsdSquareCompExample() { + } + + public static void main(String[] args) throws IOException, ModelException, TranslateException 
{ + Path imageFile = Paths.get("src/test/resources/template.png"); + Image template = OpenCVImageFactory.getInstance().fromFile(imageFile); + + imageFile = Paths.get("src/test/resources/ticket_0.png"); + Image img = OpenCVImageFactory.getInstance().fromFile(imageFile); + + try (SingleRecognitionModel recognitionModel = new SingleRecognitionModel(); + MlsdSquareModel mlsdSquareModel = new MlsdSquareModel(); + NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) { + + recognitionModel.init("models/ch_PP-OCRv3_det_infer_onnx.zip", "models/ch_PP-OCRv3_rec_infer_onnx.zip"); + mlsdSquareModel.init("models/mlsd_traced_model_onnx.zip"); + + Image templateCropImg = mlsdSquareModel.predict(template); + ImageUtils.saveImage(templateCropImg, "templateCrop.png", "build/output"); + // 模版文本检测 1 + // Text detection area + List templateTexts = new ArrayList<>(); + List templateTextsDet = recognitionModel.predict(manager, templateCropImg); + for (RotatedBox rotatedBox : templateTextsDet) { + LabelBean labelBean = new LabelBean(); + List points = new ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + labelBean.setField(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + templateTexts.add(labelBean); + } + + + Image targetCropImg = mlsdSquareModel.predict(img); + NDArray array = NDImageUtils.resize(targetCropImg.toNDArray(manager), templateCropImg.getWidth(), templateCropImg.getHeight(), Image.Interpolation.BILINEAR); + targetCropImg = OpenCVImageFactory.getInstance().fromNDArray(array); + ImageUtils.saveImage(targetCropImg, "imgCrop.png", "build/output"); + + // 目标文本检测 2 + // Text detection area + List targetTexts = new ArrayList<>(); + List textDetections = recognitionModel.predict(manager, 
targetCropImg); + for (RotatedBox rotatedBox : textDetections) { + LabelBean labelBean = new LabelBean(); + List points = new ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + targetTexts.add(labelBean); + } + + + Map hashMap; + String distance = "IOU"; + if (distance.equalsIgnoreCase("IOU")) { + hashMap = DistanceUtils.iou(templateTexts, targetTexts); + } else { + hashMap = DistanceUtils.l2Distance(templateTexts, targetTexts); + } + + Iterator> iterator = hashMap.entrySet().iterator(); + + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getKey().trim().equals("") && entry.getValue().trim().equals("")) + continue; + System.out.println(entry.getKey() + " : " + entry.getValue()); + } + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrDirectionExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrDirectionExample.java new file mode 100755 index 00000000..63597897 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrDirectionExample.java @@ -0,0 +1,58 @@ +package me.aias.example; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; +import me.aias.example.utils.cls.OcrDirectionDetection; +import me.aias.example.utils.common.DirectionInfo; +import me.aias.example.utils.common.ImageUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; 
+import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +/** + * OCR文字方向检测(轻量级模型). + * + * OCR text direction detection (light model) + * + * @author Calvin + * @date 2021-10-04 + * @email 179209347@qq.com + */ +public final class OcrDirectionExample { + + private static final Logger logger = LoggerFactory.getLogger(OcrDirectionExample.class); + + private OcrDirectionExample() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/ticket_90.png"); + Image image = ImageFactory.getInstance().fromFile(imageFile); + + OcrDirectionDetection detection = new OcrDirectionDetection(); + try (ZooModel detectionModel = ModelZoo.loadModel(detection.detectCriteria()); + Predictor detector = detectionModel.newPredictor(); + ZooModel rotateModel = ModelZoo.loadModel(detection.clsCriteria()); + Predictor rotateClassifier = rotateModel.newPredictor()) { + + DetectedObjects detections = detection.predict(image,detector,rotateClassifier); + + List boxes = detections.items(); + for (DetectedObjects.DetectedObject result : boxes) { + System.out.println(result.getClassName() + " : " + result.getProbability()); + } + + ImageUtils.saveBoundingBoxImage(image, detections, "cls_detect_result.png", "build/output"); + logger.info("{}", detections); + } + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3ProjStackRecExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3ProjStackRecExample.java new file mode 100755 index 00000000..0d9a8c0b --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3ProjStackRecExample.java @@ -0,0 +1,114 @@ +package me.aias.example; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import 
ai.djl.translate.TranslateException; +import ai.djl.util.Pair; +import me.aias.example.model.SingleRecognitionModel; +import me.aias.example.utils.common.*; +import me.aias.example.utils.common.Point; +import me.aias.example.utils.opencv.OpenCVUtils; +import org.opencv.core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public final class OcrV3ProjStackRecExample { + + private static final Logger logger = LoggerFactory.getLogger(OcrV3ProjStackRecExample.class); + + private OcrV3ProjStackRecExample() { + } + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/template.png"); + Image templateImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + + imageFile = Paths.get("src/test/resources/warp1.png"); + Image targetImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + Image origTargetImg = targetImg.duplicate(); + try (SingleRecognitionModel recognitionModel = new SingleRecognitionModel(); + NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) { + recognitionModel.init("models/ch_PP-OCRv3_det_infer_onnx.zip", "models/ch_PP-OCRv3_rec_infer_onnx.zip"); + + List projList = new ArrayList<>(); + + for (int num = 0; num < 4; num++) { + Pair pair = ProjUtils.projPointsPair(manager, recognitionModel, templateImg, targetImg); + List srcQuadPoints = (List) pair.getKey(); + List dstQuadPoints = (List) pair.getValue(); + + // [516.74072265625, 367.02178955078125, 335.10894775390625, 578.5404052734375] + // [1.0, 1.0, 2.2360680103302, 1.4142135381698608] + + // 计算距离 + double[] distances = new double[4]; + for (int i = 0; i < 4; i++) { + distances[i] = PointUtils.distance(srcQuadPoints.get(i), 
dstQuadPoints.get(i)); + } + + System.out.println(Arrays.toString(distances)); + + boolean pass = true; + for (int i = 0; i < 4; i++) { + if (distances[i] > 2) { + pass = false; + break; + } + } + + if (!pass) { + ProjItemBean projItemBean = ProjUtils.projTransform(srcQuadPoints, dstQuadPoints, templateImg, targetImg); + targetImg = projItemBean.getImage(); + ImageUtils.saveImage(projItemBean.getImage(), "perspectiveTransform_" + num + ".png", "build/output"); + + projList.add(projItemBean); + + } else { + break; + } + } + if (projList.size() > 0) { + org.opencv.core.Mat warp_mat = projList.get(projList.size() - 1).getWarpMat(); + if(projList.size() > 1){ + for (int i = projList.size() - 2; i >= 0; i--) { + org.opencv.core.Mat matItem = projList.get(i).getWarpMat(); + warp_mat = warp_mat.matMul(matItem); + } + } + + org.opencv.core.Mat mat = OpenCVUtils.warpPerspective((Mat) origTargetImg.getWrappedImage(), (Mat) templateImg.getWrappedImage(), warp_mat); + Image finalImg = OpenCVImageFactory.getInstance().fromImage(mat); + ImageUtils.saveImage(finalImg, "perspectiveTransform_final.png", "build/output"); + } + + } + } + + public static void save(Image image, List srcQuadPoints, List dstQuadPoints) { + // 转 BufferedImage 解决 Imgproc.putText 中文乱码问题 + Mat matImage = (Mat) image.getWrappedImage(); + BufferedImage buffImage = OpenCVUtils.mat2Image(matImage); + Color c = new Color(0, 255, 0); + for (int i = 0; i < 4; i++) { + DJLImageUtils.drawImageRect(buffImage, dstQuadPoints.get(i).getX(), dstQuadPoints.get(i).getY(), 6, 6, c); + DJLImageUtils.drawImageRect(buffImage, srcQuadPoints.get(i).getX(), srcQuadPoints.get(i).getY(), 6, 6); + + } + Mat pointMat = OpenCVUtils.image2Mat(buffImage); + Image pointImg = OpenCVImageFactory.getInstance().fromImage(pointMat); + ImageUtils.saveImage(pointImg, "points_result.png", "build/output"); + } + +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3RecExample.java 
b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3RecExample.java new file mode 100755 index 00000000..1d53a01e --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/OcrV3RecExample.java @@ -0,0 +1,198 @@ +package me.aias.example; + +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.translate.TranslateException; +import me.aias.example.model.SingleRecognitionModel; +import me.aias.example.utils.common.*; +import me.aias.example.utils.opencv.OpenCVUtils; +import org.opencv.core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public final class OcrV3RecExample { + + private static final Logger logger = LoggerFactory.getLogger(OcrV3RecExample.class); + + private OcrV3RecExample() { + } + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/template.png"); + Image templateImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + + imageFile = Paths.get("src/test/resources/perspectiveTransform2.png"); + Image targetImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + + try (SingleRecognitionModel recognitionModel = new SingleRecognitionModel(); + NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) { + recognitionModel.init("models/ch_PP-OCRv3_det_infer_onnx.zip", "models/ch_PP-OCRv3_rec_infer_onnx.zip"); + + // 模版文本检测 1 + // Text detection area + List templateTexts = new ArrayList<>(); + List templateTextsDet = recognitionModel.predict(manager, templateImg); + for (RotatedBox rotatedBox : templateTextsDet) { + LabelBean labelBean = new LabelBean(); + List points = new 
ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + labelBean.setField(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + templateTexts.add(labelBean); + } + + // 转 BufferedImage 解决 Imgproc.putText 中文乱码问题 + Mat wrappedImage = (Mat) templateImg.getWrappedImage(); + BufferedImage bufferedImage = OpenCVUtils.mat2Image(wrappedImage); + for (RotatedBox result : templateTextsDet) { + ImageUtils.drawImageRectWithText(bufferedImage, result.getBox(), result.getText()); + } + + Mat image2Mat = OpenCVUtils.image2Mat(bufferedImage); + templateImg = OpenCVImageFactory.getInstance().fromImage(image2Mat); + ImageUtils.saveImage(templateImg, "ocr_result.png", "build/output"); + + // 目标文本检测 2 + // Text detection area + List targetTexts = new ArrayList<>(); + List textDetections = recognitionModel.predict(manager, targetImg); + for (RotatedBox rotatedBox : textDetections) { + LabelBean labelBean = new LabelBean(); + List points = new ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + targetTexts.add(labelBean); + } + + List srcPoints = new ArrayList<>(); + List dstPoints = new ArrayList<>(); + for (int i = 0; i < templateTexts.size(); i++) { + String anchorText = templateTexts.get(i).getValue(); + for (int j = 0; j < targetTexts.size(); j++) { + String detectedText = targetTexts.get(j).getValue(); + if (detectedText.equals(anchorText)) { + dstPoints.add(templateTexts.get(i)); + srcPoints.add(targetTexts.get(j)); + } 
+ } + } + + List srcPointsList = new ArrayList<>(); + List dstPointsList = new ArrayList<>(); + + for (int i = 0; i < srcPoints.size(); i++) { + for (int j = i + 1; j < srcPoints.size(); j++) { + for (int k = j + 1; k < srcPoints.size(); k++) { + for (int l = k + 1; l < srcPoints.size(); l++) { + double[][] srcArr = new double[4][2]; + srcArr[0][0] = srcPoints.get(i).getCenterPoint().getX(); + srcArr[0][1] = srcPoints.get(i).getCenterPoint().getY(); + srcArr[1][0] = srcPoints.get(j).getCenterPoint().getX(); + srcArr[1][1] = srcPoints.get(j).getCenterPoint().getY(); + srcArr[2][0] = srcPoints.get(k).getCenterPoint().getX(); + srcArr[2][1] = srcPoints.get(k).getCenterPoint().getY(); + srcArr[3][0] = srcPoints.get(l).getCenterPoint().getX(); + srcArr[3][1] = srcPoints.get(l).getCenterPoint().getY(); + srcPointsList.add(srcArr); + + double[][] dstArr = new double[4][2]; + dstArr[0][0] = dstPoints.get(i).getCenterPoint().getX(); + dstArr[0][1] = dstPoints.get(i).getCenterPoint().getY(); + dstArr[1][0] = dstPoints.get(j).getCenterPoint().getX(); + dstArr[1][1] = dstPoints.get(j).getCenterPoint().getY(); + dstArr[2][0] = dstPoints.get(k).getCenterPoint().getX(); + dstArr[2][1] = dstPoints.get(k).getCenterPoint().getY(); + dstArr[3][0] = dstPoints.get(l).getCenterPoint().getX(); + dstArr[3][1] = dstPoints.get(l).getCenterPoint().getY(); + dstPointsList.add(dstArr); + } + } + } + } + + // 根据海伦公式(Heron's formula)计算4边形面积 + double maxArea = 0; + int index = -1; + for (int i = 0; i < dstPointsList.size(); i++) { + double[][] dstArr = dstPointsList.get(i); + double area = PointUtils.getQuadArea(manager, dstArr); + if (area > maxArea) { + maxArea = area; + index = i; + } + + } + + double[][] srcArr = srcPointsList.get(index); + double[][] dstArr = dstPointsList.get(index); +// // 转 BufferedImage 解决 Imgproc.putText 中文乱码问题 +// Mat matImage = (Mat) targetImg.getWrappedImage(); +// BufferedImage buffImage = OpenCVUtils.mat2Image(matImage); +// for (int i = 0; i < 4; i++) { +// 
DJLImageUtils.drawImageRect(buffImage, (int) dstArr[i][0], (int) dstArr[i][1], 4, 4); +// +// } +// Mat pointMat = OpenCVUtils.image2Mat(buffImage); +// Image pointImg = OpenCVImageFactory.getInstance().fromImage(pointMat); +// ImageUtils.saveImage(pointImg, "points_result.png", "build/output"); + + + List srcQuadPoints = new ArrayList<>(); + List dstQuadPoints = new ArrayList<>(); + + for (int i = 0; i < 4; i++) { + double x = srcArr[i][0]; + double y = srcArr[i][1]; + Point point = new Point((int) x, (int) y); + srcQuadPoints.add(point); + } + + for (int i = 0; i < 4; i++) { + double x = dstArr[i][0]; + double y = dstArr[i][1]; + Point point = new Point((int) x, (int) y); + dstQuadPoints.add(point); + } + + org.opencv.core.Mat srcPoint2f = OpenCVUtils.toMat(srcQuadPoints); + org.opencv.core.Mat dstPoint2f = OpenCVUtils.toMat(dstQuadPoints); + + //4点透视变换 + // 4-point perspective transformation + org.opencv.core.Mat mat = OpenCVUtils.perspectiveTransform((org.opencv.core.Mat) targetImg.getWrappedImage(), (org.opencv.core.Mat) templateImg.getWrappedImage(), srcPoint2f, dstPoint2f); + Image newImg = OpenCVImageFactory.getInstance().fromImage(mat); + ImageUtils.saveImage(newImg, "perspectiveTransform.png", "build/output"); + + + System.out.println("end"); + + + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecPoolExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecPoolExample.java new file mode 100644 index 00000000..5cbb8a48 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecPoolExample.java @@ -0,0 +1,66 @@ +package me.aias.example; + + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDList; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.repository.zoo.Criteria; +import 
ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import me.aias.example.model.RecognitionModel; +import me.aias.example.utils.common.DJLImageUtils; +import me.aias.example.utils.common.LabelBean; +import me.aias.example.utils.common.PointUtils; +import me.aias.example.utils.detection.OCRDetectionTranslator; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +public class RecPoolExample { + + public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException { + Path imageFile = Paths.get("src/test/resources/template.png"); + Image templateImg = OpenCVImageFactory.getInstance().fromFile(imageFile); + + int nThreads = 5; // 并发数量 + + RecognitionModel recognitionModel = new RecognitionModel(); + recognitionModel.init("models/ch_PP-OCRv3_det_infer_onnx.zip", "models/ch_PP-OCRv3_rec_infer_onnx.zip", 4); + + ExecutorService executorService = Executors.newFixedThreadPool(nThreads); // 3是线程池的大小 + + for (int i = 0; i < 10; i++) { + final int index = i; + executorService.execute(new Runnable() { + public void run() { + // 这里是需要异步执行的代码 + try { + DetectedObjects textDetections = recognitionModel.predict(templateImg); + List dt_boxes = textDetections.items(); + for (DetectedObjects.DetectedObject item : dt_boxes) { + System.out.println(item.getClassName()); + } + } catch (TranslateException e) { + e.printStackTrace(); + } + } + }); + } + executorService.shutdown(); // 当所有任务执行完毕后关闭线程池 + + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecognizerPool.java 
b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecognizerPool.java new file mode 100644 index 00000000..a0f0ea24 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RecognizerPool.java @@ -0,0 +1,49 @@ +package me.aias.example;// 导入需要的包 + +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.repository.zoo.ZooModel; + +import java.util.ArrayList; + +public class RecognizerPool { + private int poolSize; + private ZooModel recognitionModel; + private ArrayList> recognizerList = new ArrayList<>(); + + + public RecognizerPool(int poolSize, ZooModel detectionModel) { + this.poolSize = poolSize; + this.recognitionModel = detectionModel; + + for (int i = 0; i < poolSize; i++) { + Predictor detector = detectionModel.newPredictor(); + recognizerList.add(detector); + } + } + + public synchronized Predictor getRecognizer(){ + while (recognizerList.isEmpty()) { + try { + wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + Predictor recognizer = recognizerList.remove(0); + return recognizer; + } + + public synchronized void releaseRecognizer(Predictor recognizer) { + recognizerList.add(recognizer); + notifyAll(); + } + + public void close() { + recognitionModel.close(); + for (Predictor detector : recognizerList) { + detector.close(); + } + + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RotationExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RotationExample.java new file mode 100755 index 00000000..a815e3c7 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/RotationExample.java @@ -0,0 +1,59 @@ +package me.aias.example; + +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import 
ai.djl.translate.TranslateException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * 图片旋转 + * Rotation Example + * + * @author Calvin + * @date 2021-06-28 + * @email 179209347@qq.com + */ +public final class RotationExample { + + private static final Logger logger = LoggerFactory.getLogger(RotationExample.class); + + private RotationExample() { + } + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/ticket_0.png"); + Image image = ImageFactory.getInstance().fromFile(imageFile); + // 逆时针旋转 + // Counterclockwise rotation + image = rotateImg(image); + + saveImage(image, "rotate_result.png", "build/output"); + } + + private static Image rotateImg(Image image) { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1); + return ImageFactory.getInstance().fromNDArray(rotated); + } + } + + public static void saveImage(Image img, String name, String path) { + Path outputDir = Paths.get(path); + Path imagePath = outputDir.resolve(name); + try { + img.save(Files.newOutputStream(imagePath), "png"); + } catch (IOException e) { + e.printStackTrace(); + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/TestExample.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/TestExample.java new file mode 100755 index 00000000..f286177c --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/TestExample.java @@ -0,0 +1,40 @@ +package me.aias.example; + +import ai.djl.ModelException; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import 
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Prints every 4-element combination (in ascending index order) of a fixed array.
 */
public final class TestExample {

    private TestExample() {
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5, 6};
        int n = arr.length;
        // Enumerate index quadruples i < j < k < l so each 4-subset is produced exactly once.
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    for (int l = k + 1; l < n; l++) {
                        int[] combination = {arr[i], arr[j], arr[k], arr[l]};
                        // do something with combination
                        System.out.println(Arrays.toString(combination));
                    }
                }
            }
        }
    }
}

// ---- ThreadExample.java ----
package me.aias.example;

import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Minimal fixed-size thread-pool demo: submits 10 no-op tasks to a pool of 3 threads.
 */
public final class ThreadExample {

    private ThreadExample() {
    }

    public static void main(String[] args) {
        ExecutorService threadPool = Executors.newFixedThreadPool(3); // pool size: 3 threads
        for (int i = 0; i < 10; i++) {
            // Lambda in place of the original anonymous Runnable; identical task body.
            threadPool.execute(() -> System.out.println(""));
        }
        // Shut the pool down once all queued tasks have completed.
        threadPool.shutdown();
    }
}
package me.aias.example.model;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.opencv.OpenCVImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import me.aias.example.utils.detection.PpWordDetectionTranslator;
import me.aias.example.utils.recognition.PpWordRecognitionTranslator;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Text recognition for images that are already deskewed (upright).
 *
 * <p>Pipeline: PP-OCR text-box detection, crop each box (with margin), rotate tall narrow crops
 * (likely vertical text), then run line recognition on each crop.
 *
 * <p>NOTE: a single Predictor is not thread-safe; predict methods are synchronized. For real
 * multithreading, replace the predictors with per-thread pools sharing one model.
 *
 * @author Calvin
 * @date Oct 19, 2021
 */
public final class AlignedRecognitionModel implements AutoCloseable {
    private ZooModel<Image, DetectedObjects> detectionModel;
    private Predictor<Image, DetectedObjects> detector;
    private ZooModel<Image, String> recognitionModel;
    private Predictor<Image, String> recognizer;

    /**
     * Loads both ONNX models and creates one predictor for each.
     *
     * @param detModel path to the detection model file
     * @param recModel path to the recognition model file
     */
    public void init(String detModel, String recModel)
            throws MalformedModelException, ModelNotFoundException, IOException {
        this.recognitionModel = ModelZoo.loadModel(recognizeCriteria(recModel));
        this.recognizer = recognitionModel.newPredictor();
        this.detectionModel = ModelZoo.loadModel(detectCriteria(detModel));
        this.detector = detectionModel.newPredictor();
    }

    /** Releases native resources. FIX: close each predictor before its owning model. */
    public void close() {
        this.recognizer.close();
        this.recognitionModel.close();
        this.detector.close();
        this.detectionModel.close();
    }

    /** Criteria for the PP-OCR text-box detection model (image -> detected boxes). */
    private Criteria<Image, DetectedObjects> detectCriteria(String detUri) {
        return Criteria.builder()
                .optEngine("OnnxRuntime")
                .optModelName("inference")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelPath(Paths.get(detUri))
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();
    }

    /** Criteria for the PP-OCR line recognition model (cropped line image -> text). */
    private Criteria<Image, String> recognizeCriteria(String recUri) {
        return Criteria.builder()
                .optEngine("OnnxRuntime")
                .optModelName("inference")
                .setTypes(Image.class, String.class)
                .optModelPath(Paths.get(recUri))
                .optProgress(new ProgressBar())
                .optTranslator(new PpWordRecognitionTranslator(new ConcurrentHashMap<String, String>()))
                .build();
    }

    /**
     * Recognizes a single pre-cropped text line.
     * Synchronized because the underlying predictor is not thread-safe.
     */
    public synchronized String predictSingleLineText(Image image) throws TranslateException {
        return recognizer.predict(image);
    }

    /**
     * Detects text boxes and recognizes the text in each one.
     *
     * @param image upright input image
     * @return detections whose class names carry the recognized text; confidence is not
     *         reported by the recognizer, so every probability is the placeholder -1.0
     */
    public synchronized DetectedObjects predict(Image image) throws TranslateException {
        DetectedObjects detections = detector.predict(image);
        List<DetectedObjects.DetectedObject> boxes = detections.items();
        List<String> names = new ArrayList<>();
        List<Double> prob = new ArrayList<>();
        List<BoundingBox> rect = new ArrayList<>();
        for (DetectedObjects.DetectedObject box : boxes) {
            Image subImg = getSubImage(image, box.getBoundingBox());
            // A crop much taller than wide is likely vertical text; rotate it before recognition.
            if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
                subImg = rotateImg(subImg);
            }
            // FIX: removed stray System.out.println(name) debug output from the hot path.
            names.add(recognizer.predict(subImg));
            prob.add(-1.0);
            rect.add(box.getBoundingBox());
        }
        return new DetectedObjects(names, prob, rect);
    }

    /** Crops the (normalized-coordinate) box out of the image, with the margin from extendRect. */
    private Image getSubImage(Image img, BoundingBox box) {
        Rectangle rect = box.getBounds();
        double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
        int width = img.getWidth();
        int height = img.getHeight();
        // Convert normalized [0,1] coordinates back to pixels.
        int[] recovered = {
            (int) (extended[0] * width),
            (int) (extended[1] * height),
            (int) (extended[2] * width),
            (int) (extended[3] * height)
        };
        return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
    }

    /**
     * Grows a normalized rectangle around its center (the shorter side is tripled) and clamps
     * the result to the [0,1] image bounds.
     *
     * @return {x, y, width, height} in normalized coordinates
     */
    private double[] extendRect(double xmin, double ymin, double width, double height) {
        double centerx = xmin + width / 2;
        double centery = ymin + height / 2;
        if (width > height) {
            width += height * 2.0;
            height *= 3.0;
        } else {
            height += width * 2.0;
            width *= 3.0;
        }
        double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;
        double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;
        double newWidth = newX + width > 1 ? 1 - newX : width;
        double newHeight = newY + height > 1 ? 1 - newY : height;
        return new double[]{newX, newY, newWidth, newHeight};
    }

    /** Rotates the image 90 degrees counterclockwise (used for vertical text crops). */
    private Image rotateImg(Image image) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
            return OpenCVImageFactory.getInstance().fromNDArray(rotated);
        }
    }
}
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import me.aias.example.utils.common.ImageUtils;
import me.aias.example.utils.opencv.NDArrayUtils;
import me.aias.example.utils.opencv.OpenCVUtils;
import org.opencv.core.Mat;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;

/**
 * M-LSD line-segment based square (document quadrilateral) detector: finds the best-scoring
 * square in the image and returns a perspective-corrected crop of it.
 *
 * @author Calvin
 * @date Jun 18, 2023
 */
public final class MlsdSquareModel implements AutoCloseable {
    private ZooModel<Image, Image> model;
    private Predictor<Image, Image> predictor;
    // NOTE(review): thr_v / thr_d appear unused in this class — confirm before removing.
    private float thr_v = 0.1f;
    private float thr_d = 0.1f;
    // Target resolution the input is resized to (shorter side, rounded to a multiple of 64).
    private int detect_resolution = 512;

    /** Loads the ONNX model and creates one predictor. */
    public void init(String modelUri) throws MalformedModelException, ModelNotFoundException, IOException {
        this.model = ModelZoo.loadModel(onnxCriteria(modelUri));
        this.predictor = this.model.newPredictor();
    }

    /** Releases native resources. */
    public void close() {
        this.model.close();
        this.predictor.close();
    }

    // In a multithreaded environment, replace this single Predictor with a predictor pool
    // (one predictor per thread, all sharing one model).
    public synchronized Image predict(Image image) throws TranslateException {
        Image cropImg = predictor.predict(image);
        return cropImg;
    }

    /** Criteria for the M-LSD ONNX model (image in, cropped square image out). */
    private Criteria<Image, Image> onnxCriteria(String modelUri) {

        Criteria<Image, Image> criteria =
                Criteria.builder()
                        .optEngine("OnnxRuntime")
                        .setTypes(Image.class, Image.class)
                        .optModelName("mlsd_traced_model")
                        .optModelPath(Paths.get(modelUri))
                        .optDevice(Device.cpu())
//                        .optDevice(Device.gpu())
                        .optTranslator(new FeatureTranslator())
                        .optProgress(new ProgressBar())
                        .build();

        return criteria;
    }

    /**
     * Pre/post-processing for the M-LSD model. processOutput implements the square-detection
     * post-processing (junction decoding, Hough-style line merging, intersection/corner
     * analysis, square scoring) ported from the original Python implementation.
     */
    private final class FeatureTranslator implements Translator<Image, Image> {
        protected Batchifier batchifier = Batchifier.STACK;
        private int topk_n = 200;           // top-K junction candidates kept
        private int ksize = 3;              // NMS max-pool kernel size
        private float score = 0.06f;        // junction score threshold
        private float outside_ratio = 0.28f;
        private float inside_ratio = 0.45f;
        // Weights of the individual square scores in the final ranking.
        private float w_overlap = 0.0f;
        private float w_degree = 1.95f;
        private float w_length = 0.0f;
        private float w_area = 1.86f;
        private float w_center = 0.1f;
        private NDArray imgArray;           // original image as UINT8 HWC, kept for final crop
        // private int width;
//        private int height;
        //
        private int original_shape[] = new int[2];  // {h, w} of the input image
        private int input_shape[] = new int[2];     // {h, w} fed to the network

        FeatureTranslator() {
        }

        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            try (NDManager manager = NDManager.newBaseManager(ctx.getNDManager().getDevice(), "PyTorch")) {
                original_shape[1] = input.getWidth(); // w - input_shape[1]
                original_shape[0] = input.getHeight(); // h - input_shape[0]

                NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);

                array = array.toType(DataType.UINT8, false);

                // Keep the unresized original for the final perspective crop.
                imgArray = array;

//                NDArray padding_im = ctx.getNDManager().zeros(new Shape(array.getShape().get(0) + 200, array.getShape().get(1) + 200, array.getShape().get(2)), DataType.FLOAT32);
//                padding_im.set(new NDIndex("100:" + (original_shape[0] + 100) + ",100:" + (original_shape[1] + 100) + ",:"), imgArray);

                // h : input_shape[0], w : input_shape[1]
                input_shape = resize64(original_shape[0], original_shape[1], detect_resolution);

                array = NDImageUtils.resize(array, input_shape[1], input_shape[0], Image.Interpolation.AREA);

                // Append an all-ones alpha-like 4th channel, as the model expects 4-channel input.
                NDArray ones = manager.ones(new Shape(array.getShape().get(0), array.getShape().get(1), 1), DataType.UINT8);

                array = array.concat(ones, -1);

                array = array.transpose(2, 0, 1); // HWC -> CHW RGB

                array = array.toType(DataType.FLOAT32, false);

                // Normalize to [-1, 1].
                array = array.div(127.5f).sub(1.0f);

                // Reverse the channel order (RGB <-> BGR).
                array = array.flip(0);

                return new NDList(array);
            }
        }

        @Override
        public Image processOutput(TranslatorContext ctx, NDList list) {
            try (NDManager manager = NDManager.newBaseManager(ctx.getNDManager().getDevice(), "PyTorch")) {

                NDArray tpMap = list.singletonOrThrow();

                // deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 3) start
                int width = (int) (tpMap.getShape().get(2));

                // Channels 1..4: per-pixel start/end point displacements of a line segment.
                NDArray displacement = tpMap.get("1:5, :, :");

                // Channel 0: junction center heat map.
                NDArray center = tpMap.get("0, :, :");

                // Sigmoid function: f(x) = 1 / (1 + e^-x)
                NDArray heat = NDArrayUtils.Sigmoid(center);

                // 3x3 max-pool NMS: keep only local maxima of the heat map.
                NDArray hmax = NDArrayUtils.maxPool(manager, heat, ksize, 1, (ksize - 1) / 2);

                NDArray keep = hmax.eq(heat);
                keep = keep.toType(DataType.FLOAT32, false);

                heat = heat.mul(keep);
                heat = heat.reshape(-1);

                // Top-200 junction candidates by score; recover (y, x) from the flat index.
                NDArray indices = heat.argSort(-1, false).get("0:200");
                NDArray pts_score = heat.get(indices);
                indices = indices.toType(DataType.FLOAT32, true);
                NDArray yy = indices.div(width).floor().expandDims(-1);
                NDArray xx = indices.mod(width).expandDims(-1);
                NDArray pts = yy.concat(xx, -1);

                NDArray vmap = displacement.transpose(1, 2, 0);
                // deccode_output_score_and_ptss end

                NDArray start = vmap.get(":, :, :2");
                NDArray end = vmap.get(":, :, 2:");

                // Per-pixel predicted segment length.
                NDArray dist_map = start.sub(end).pow(2).sum(new int[]{-1}).sqrt();

                // NOTE(review): junc_list is filled but never read afterwards.
                ArrayList<float[]> junc_list = new ArrayList<>();
                ArrayList<float[]> segments_list = new ArrayList<>();

                // Keep junctions above the score threshold whose segment is longer than 20px.
                for (int i = 0; i < pts_score.size(); i++) {
                    center = pts.get(i);
                    int y = (int) center.getFloat(0);
                    int x = (int) center.getFloat(1);
                    float score = pts_score.getFloat(i);
                    float distance = dist_map.getFloat(y, x);

                    if (score > this.score && distance > 20.0f) {
                        float[] junc = new float[2];
                        junc[0] = x;
                        junc[1] = y;
                        junc_list.add(junc);

                        NDArray array = vmap.get(y + "," + x + ",:"); // y, x, :
                        float disp_x_start = array.getFloat(0);
                        float disp_y_start = array.getFloat(1);
                        float disp_x_end = array.getFloat(2);
                        float disp_y_end = array.getFloat(3);

                        // Segment endpoints = junction position + predicted displacements.
                        float x_start = x + disp_x_start;
                        float y_start = y + disp_y_start;
                        float x_end = x + disp_x_end;
                        float y_end = y + disp_y_end;

                        float[] segment = new float[4];
                        segment[0] = x_start;
                        segment[1] = y_start;
                        segment[2] = x_end;
                        segment[3] = y_end;

                        segments_list.add(segment);
                    }
                }

                float[][] segmentsArr = new float[segments_list.size()][4];
                for (int i = 0; i < segments_list.size(); i++) {
                    float[] item = segments_list.get(i);
                    segmentsArr[i][0] = item[0];
                    segmentsArr[i][1] = item[1];
                    segmentsArr[i][2] = item[2];
                    segmentsArr[i][3] = item[3];
                }

                NDArray segments = manager.create(segmentsArr).toType(DataType.FLOAT32, false);

                // ####### post processing for squares
                // 1. get unique lines
                start = segments.get(":, :2");
                end = segments.get(":, 2:");
                NDArray diff = start.sub(end);

                // Line through the segment in implicit form a*x + b*y = c;
                // d is the (perpendicular) distance of that line from the origin.
                NDArray a = diff.get(":, 1");
                NDArray b = diff.get(":, 0").neg();
                NDArray c = a.mul(start.get(":, 0")).add(b.mul(start.get(":, 1")));
                NDArray d = c.abs().div(a.square().add(b.square().add(Math.exp(-10))).sqrt());

                // Segment orientation in degrees, mapped to [0, 180).
                NDArray theta = NDArrayUtils.arctan2(diff.get(":, 0"), diff.get(":, 1"));
                NDArray index = theta.lt(0.0f);
                index = index.toType(DataType.INT32, false).mul(180);
                theta = theta.add(index);

                // Hough-like (distance, angle) representation, then quantize.
                NDArray hough = d.expandDims(1).concat(theta.expandDims(1), -1);

                int d_quant = 1;
                int theta_quant = 2;
                hough.get(":, 0").divi(d_quant);
                hough.get(":, 1").divi(theta_quant);
                hough = hough.floor();
                float[][] houghArr = NDArrayUtils.floatNDArrayToArray(hough);

                NDList ndList = hough.unique(0, true, false, true);
                // List of unique (d, theta) bins.
                NDArray yx_indices = ndList.get(0).toType(DataType.INT32, false);
                int[][] yx_indicesArr = NDArrayUtils.intNDArrayToArray(yx_indices);
                int[] inds = new int[yx_indicesArr.length];
                // Occurrence count of each unique bin.
                NDArray counts = ndList.get(2);
                long[] countsArr = counts.toLongArray();

                // For each unique bin, find the index of its first occurrence in hough.
                for (int i = 0; i < yx_indicesArr.length; i++) {
                    for (int j = 0; j < houghArr.length; j++) {
                        if (yx_indicesArr[i][0] == houghArr[j][0] && yx_indicesArr[i][1] == houghArr[j][1]) {
                            inds[i] = j;
                            break;
                        }
                    }
                }

                // Accumulator (votes per bin) and map from bin to a representative segment index.
                NDArray acc_map = manager.zeros(new Shape(512 / d_quant + 1, 360 / theta_quant + 1), DataType.FLOAT32);
                NDArray idx_map = manager.zeros(new Shape(512 / d_quant + 1, 360 / theta_quant + 1), DataType.INT32).sub(1);

                for (int i = 0; i < yx_indicesArr.length; i++) {
                    acc_map.set(new NDIndex(yx_indicesArr[i][0], yx_indicesArr[i][1]), countsArr[i]);
                    idx_map.set(new NDIndex(yx_indicesArr[i][0], yx_indicesArr[i][1]), inds[i]);
                }

                float[][] acc_map_np = NDArrayUtils.floatNDArrayToArray(acc_map);

                // 5x5 max-pool NMS over the accumulator to keep dominant (d, theta) bins.
                NDArray max_acc_map = NDArrayUtils.maxPool(manager, acc_map, 5, 1, 2);

                keep = acc_map.eq(max_acc_map);
                keep = keep.toType(DataType.FLOAT32, false);
                acc_map = acc_map.mul(keep);
                NDArray flatten_acc_map = acc_map.flatten();

                indices = flatten_acc_map.argSort(-1, false).get("0:200");

                NDArray scores = flatten_acc_map.get(indices);
                int cols = (int) (acc_map.getShape().get(1));
                yy = indices.div(cols).floor().expandDims(-1);
                xx = indices.mod(cols).expandDims(-1);
                NDArray yx = yy.concat(xx, -1);
                float[][] yx_arr = NDArrayUtils.floatNDArrayToArray(yx);
                float[] topk_values = scores.toFloatArray();
                int[][] idx_map_arr = NDArrayUtils.intNDArrayToArray(idx_map);

                int[] indices_arr = new int[yx_arr.length];
                for (int i = 0; i < yx_arr.length; i++) {
                    indices_arr[i] = idx_map_arr[(int) yx_arr[i][0]][(int) yx_arr[i][1]];
                }

                // Merge all segments that fall into a 5x5 neighborhood of each dominant bin
                // into one long segment (bounding endpoints, oriented by theta).
                int basis = 5 / 2;
                NDArray merged_segments = manager.zeros(new Shape(0, 4), DataType.FLOAT32);
                for (int i = 0; i < yx_arr.length; i++) {
                    float[] yx_pt = yx_arr[i];
                    float y = yx_pt[0];
                    float x = yx_pt[1];
                    int max_indice = indices_arr[i];
                    float value = topk_values[i];
                    if (max_indice == -1 || value == 0) {
                        continue;
                    }

                    NDList segment_list = new NDList();
                    for (int y_offset = -basis; y_offset < basis + 1; y_offset++) {
                        for (int x_offset = -basis; x_offset < basis + 1; x_offset++) {
                            if (y + y_offset < 0 || x + x_offset < 0) {
                                continue;
                            }
                            int indice = idx_map_arr[(int) (y + y_offset)][(int) (x + x_offset)];
                            int cnt = (int) acc_map_np[(int) (y + y_offset)][(int) (x + x_offset)];
                            if (indice != -1) {
                                segment_list.add(segments.get(indice));
                            }
                            // A bin that got more than one vote maps to several segments;
                            // collect the remaining ones by scanning for equal hough rows.
                            if (cnt > 1) {
                                int check_cnt = 1;
                                NDArray current_hough = hough.get(indice);
                                for (int new_indice = 0; new_indice < hough.size(0); new_indice++) {
                                    NDArray new_hough = hough.get(new_indice);
                                    if (current_hough.eq(new_hough).all().toBooleanArray()[0] && indice != new_indice) {
                                        segment_list.add(segments.get(new_indice));
                                        check_cnt += 1;
                                        if (check_cnt == cnt)
                                            break;
                                    }
                                }

                            }
                        }
                    }

                    // Bounding endpoints of the whole group; orientation chosen by the bin angle.
                    NDArray group_segments = NDArrays.concat(segment_list).reshape(-1, 2);
                    NDArray sorted_group_segments = group_segments.sort(0);

                    float[] min = sorted_group_segments.get("0, :").toFloatArray();
                    float[] max = sorted_group_segments.get("-1, :").toFloatArray();
                    float x_min = min[0];
                    float y_min = min[1];
                    float x_max = max[0];
                    float y_max = max[1];

                    float deg = theta.get(max_indice).toFloatArray()[0];
                    if (deg >= 90) {
                        merged_segments = merged_segments.concat(manager.create(new float[]{x_min, y_max, x_max, y_min}).reshape(1, 4));
                    } else {
                        merged_segments = merged_segments.concat(manager.create(new float[]{x_min, y_min, x_max, y_max}).reshape(1, 4));
                    }
                }

                // 2. get intersections
                NDArray new_segments = merged_segments;

                start = new_segments.get(":, :2"); // (x1, y1)
                end = new_segments.get(":, 2:"); // (x2, y2)
                NDArray new_centers = start.add(end).div(2.0f);
                diff = start.sub(end);
                NDArray dist_segments = diff.square().sum(new int[]{-1}).sqrt();

                // ax + by = c
                a = diff.get(":, 1");
                b = diff.get(":, 0").neg();
                c = a.mul(start.get(":, 0")).add(b.mul(start.get(":, 1")));

                // Pairwise line intersections by Cramer's rule (eps added to avoid div-by-zero).
                NDArray pre_det = a.expandDims(1).mul(b.expandDims(0));
                NDArray det = pre_det.sub(pre_det.transpose());
                NDArray pre_inter_y = a.expandDims(1).mul(c.expandDims(0));
                NDArray inter_y = pre_inter_y.sub(pre_inter_y.transpose()).div(det.add(Math.exp(-10)));
                NDArray pre_inter_x = c.expandDims(1).mul(b.expandDims(0));
                NDArray inter_x = pre_inter_x.sub(pre_inter_x.transpose()).div(det.add(Math.exp(-10)));
                NDArray inter_pts = inter_x.expandDims(2).concat(inter_y.expandDims(2), -1).toType(DataType.INT32, false);

                // 3. get corner information
                // 3.1 get distance
                NDArray dist_inter_to_segment1_start = inter_pts.sub(start.expandDims(1)).square().sum(new int[]{-1}, true).sqrt();
                NDArray dist_inter_to_segment1_end = inter_pts.sub(end.expandDims(1)).square().sum(new int[]{-1}, true).sqrt();
                NDArray dist_inter_to_segment2_start = inter_pts.sub(start.expandDims(0)).square().sum(new int[]{-1}, true).sqrt();
                NDArray dist_inter_to_segment2_end = inter_pts.sub(end.expandDims(0)).square().sum(new int[]{-1}, true).sqrt();

                // sort ascending
                NDArray dist_inter_to_segment1 = dist_inter_to_segment1_start.concat(dist_inter_to_segment1_end, -1).sort(-1);
                NDArray dist_inter_to_segment2 = dist_inter_to_segment2_start.concat(dist_inter_to_segment2_end, -1).sort(-1);

                // 3.2 get degree
                NDArray inter_to_start = new_centers.expandDims(1).sub(inter_pts);
                NDArray deg_inter_to_start = NDArrayUtils.arctan2(inter_to_start.get(":, :, 1"), inter_to_start.get(":, :, 0"));
                index = deg_inter_to_start.lt(0.0f);
                index = index.toType(DataType.INT32, false).mul(360);
                deg_inter_to_start = deg_inter_to_start.add(index);

                NDArray inter_to_end = new_centers.expandDims(0).sub(inter_pts);

                // np.arctan2 and np.arctan both compute arctangents but differ in arguments and
                // range; np.arctan2 takes (y, x).
                NDArray deg_inter_to_end = NDArrayUtils.arctan2(inter_to_end.get(":, :, 1"), inter_to_end.get(":, :, 0"));
                index = deg_inter_to_end.lt(0.0f);
                index = index.toType(DataType.INT32, false).mul(360);
                deg_inter_to_end = deg_inter_to_end.add(index);

                // rename variables
                NDArray deg1_map = deg_inter_to_start;
                NDArray deg2_map = deg_inter_to_end;

                // sort deg ascending
                NDArray deg_sort = deg1_map.expandDims(2).concat(deg2_map.expandDims(2), -1).sort(-1);
                NDArray deg_diff_map = deg1_map.sub(deg2_map).abs();
                // we only consider the smallest degree of intersect
                // deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
                // x -> 360 - x  =>  x + 360 - 2x = 360 - x
                index = deg_diff_map.gt(180);
                NDArray val1 = index.toType(DataType.INT32, false).mul(360);
                NDArray val2 = index.toType(DataType.INT32, false).mul(deg_diff_map).neg().mul(2);

                deg_diff_map = deg_diff_map.add(val1).add(val2);

                // define available degree range
                int[] deg_range = new int[]{60, 120};
                // Candidate corners bucketed by quadrant: 0=blue, 1=green, 2=black, 3=cyan, 4=red.
                ArrayList<ArrayList<int[]>> corner_dict = new ArrayList<>();
                ArrayList<int[]> blueList = new ArrayList<>();
                ArrayList<int[]> greenList = new ArrayList<>();
                ArrayList<int[]> blackList = new ArrayList<>();
                ArrayList<int[]> cyanList = new ArrayList<>();
                ArrayList<int[]> redList = new ArrayList<>();

                corner_dict.add(blueList);
                corner_dict.add(greenList);
                corner_dict.add(blackList);
                corner_dict.add(cyanList);
                corner_dict.add(redList);

                NDArray inter_points = manager.zeros(new Shape(0, 2));

                float[] dist_segments_arr = dist_segments.toFloatArray();
                for (int i = 0; i < inter_pts.getShape().get(0); i++) {
                    for (int j = i + 1; j < inter_pts.getShape().get(1); j++) {
                        // i, j > line index, always i < j
                        int[] point1 = inter_pts.get(i + "," + j + ",:").toIntArray();
                        int x = point1[0];
                        int y = point1[1];
                        float[] point2 = deg_sort.get(i + "," + j + ",:").toFloatArray();
                        float deg1 = point2[0];
                        float deg2 = point2[1];
                        float deg_diff = deg_diff_map.getFloat(i, j);
                        // A usable corner must meet at a roughly right angle (60..120 degrees).
                        boolean check_degree = false;
                        if (deg_diff > deg_range[0] && deg_diff < deg_range[1]) {
                            check_degree = true;
                        }
                        boolean check_distance = false;

                        // The intersection must lie close enough to both segments: either just
                        // beyond an endpoint (outside_ratio) or within the segment (inside_ratio).
                        if (((dist_inter_to_segment1.getFloat(i, j, 1) >= dist_segments_arr[i] &&
                                dist_inter_to_segment1.getFloat(i, j, 0) <= dist_segments_arr[i] * this.outside_ratio) ||
                                (dist_inter_to_segment1.getFloat(i, j, 1) <= dist_segments_arr[i] &&
                                        dist_inter_to_segment1.getFloat(i, j, 0) <= dist_segments_arr[i] * this.inside_ratio)) &&
                                ((dist_inter_to_segment2.getFloat(i, j, 1) >= dist_segments_arr[j] &&
                                        dist_inter_to_segment2.getFloat(i, j, 0) <= dist_segments_arr[j] * this.outside_ratio) ||
                                        (dist_inter_to_segment2.getFloat(i, j, 1) <= dist_segments_arr[j] &&
                                                dist_inter_to_segment2.getFloat(i, j, 0) <= dist_segments_arr[j] * this.inside_ratio))) {
                            check_distance = true;
                        }

                        if (check_degree && check_distance) {
                            // Classify the corner by the directions of its two incident lines.
                            int corner_info = 0;
                            if ((deg1 >= 0 && deg1 <= 45 && deg2 >= 45 && deg2 <= 120) ||
                                    (deg2 >= 315 && deg1 >= 45 && deg1 <= 120)) {
                                corner_info = 0; // blue
                            } else if (deg1 >= 45 && deg1 <= 125 && deg2 >= 125 && deg2 <= 225) {
                                corner_info = 1; // green
                            } else if (deg1 >= 125 && deg1 <= 225 && deg2 >= 225 && deg2 <= 315) {
                                corner_info = 2; // black
                            } else if ((deg1 >= 0 && deg1 <= 45 && deg2 >= 225 && deg2 <= 315) ||
                                    (deg2 >= 315 && deg1 >= 225 && deg1 <= 315)) {
                                corner_info = 3; // cyan
                            } else {
                                corner_info = 4; // red - we don't use it
                                continue;
                            }
                            // Entry layout: {x, y, line index i, line index j}.
                            corner_dict.get(corner_info).add(new int[]{x, y, i, j});
                            inter_points = inter_points.concat(manager.create(new int[]{x, y}).reshape(1, 2));
                        }

                    }
                }

                NDArray square_list = manager.zeros(new Shape(0, 8));
                NDArray connect_list = manager.zeros(new Shape(0, 4));
                NDArray segment_list = manager.zeros(new Shape(0, 8));

                // Chain corners blue -> green -> black -> cyan -> back to blue; four corners that
                // consecutively share a line form a candidate square.
                int corner0_line = 0;
                int corner1_line = 0;
                int corner2_line = 0;
                int corner3_line = 0;
                for (int[] corner0 : corner_dict.get(0)) {
                    for (int[] corner1 : corner_dict.get(1)) {
                        boolean connect01 = false;
                        for (int i = 0; i < 2; i++) {
                            corner0_line = corner0[2 + i];
                            for (int j = 0; j < 2; j++) {
                                if (corner0_line == corner1[2 + j]) {
                                    connect01 = true;
                                    break;
                                }
                            }
                        }
                        if (connect01) {
                            for (int[] corner2 : corner_dict.get(2)) {
                                boolean connect12 = false;
                                for (int i = 0; i < 2; i++) {
                                    corner1_line = corner1[2 + i];
                                    for (int j = 0; j < 2; j++) {
                                        if (corner1_line == corner2[2 + j]) {
                                            connect12 = true;
                                            break;
                                        }
                                    }
                                }
                                if (connect12) {
                                    for (int[] corner3 : corner_dict.get(3)) {
                                        boolean connect23 = false;
                                        for (int i = 0; i < 2; i++) {
                                            // NOTE(review): reads corner1/corner2 here; by the
                                            // pattern of the other levels this looks like it
                                            // should be corner2[2 + i] vs corner3[2 + j] —
                                            // confirm against the upstream Python before fixing.
                                            corner2_line = corner1[2 + i];
                                            for (int j = 0; j < 2; j++) {
                                                if (corner2_line == corner2[2 + j]) {
                                                    connect23 = true;
                                                    break;
                                                }
                                            }
                                        }
                                        if (connect23) {
                                            for (int i = 0; i < 2; i++) {
                                                corner3_line = corner3[2 + i];
                                                for (int j = 0; j < 2; j++) {
                                                    if (corner3_line == corner0[2 + j]) {
                                                        square_list = square_list.concat(manager.create(new int[]{corner0[0], corner0[1], corner1[0], corner1[1], corner2[0], corner2[1], corner3[0], corner3[1]}).reshape(1, 8));
                                                        connect_list = connect_list.concat(manager.create(new int[]{corner0_line, corner1_line, corner2_line, corner3_line}).reshape(1, 4));
                                                        segment_list = segment_list.concat(manager.create(new int[]{corner0[2], corner0[3], corner1[2], corner1[3], corner2[2], corner2[3], corner3[2], corner3[3]}).reshape(1, 8));
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }

                            }
                        }

                    }
                }

                // Integer division happens first, then widening to float (value is a whole number).
                float map_size = (int) imgArray.getShape().get(0) / 2;
                NDArray squares = square_list.reshape(-1, 4, 2);
                NDArray score_array = null;
                NDArray connect_array = connect_list;
                NDArray segments_array = segment_list.reshape(-1, 4, 2);
                //get degree of corners:

                // Manual np.roll(squares, 1/-1, axis=1) to pair each corner with its neighbors.
                NDArray squares_rollup = squares.duplicate();
                NDArray last = squares.get(":," + (squares.size(1) - 1) + ",:");
                for (int i = ((int) squares.size(1) - 1); i > 0; i--) {
                    squares_rollup.set(new NDIndex(":," + i + ",:"), squares.get(":," + (i - 1) + ",:"));
                }
                squares_rollup.set(new NDIndex(":,0,:"), last);

                NDArray squares_rolldown = manager.zeros(squares.getShape());
                NDArray first = squares.get(":,0,:");
                for (int i = 0; i < squares.size(1) - 1; i++) {
                    squares_rolldown.set(new NDIndex(":," + i + ",:"), squares.get(":," + (i + 1) + ",:"));
                }
                squares_rolldown.set(new NDIndex(":," + (squares.size(1) - 1) + ",:"), first);

                // Interior angle at each corner from the dot product of its two edge vectors.
                NDArray vec1 = squares_rollup.sub(squares);
                NDArray normalized_vec1 = vec1.div(vec1.norm(new int[]{-1}, true).add(Math.exp(-10)));

                NDArray vec2 = squares_rolldown.sub(squares);
                NDArray normalized_vec2 = vec2.div(vec2.norm(new int[]{-1}, true).add(Math.exp(-10)));

                NDArray inner_products = normalized_vec1.mul(normalized_vec2).sum(new int[]{-1});

                NDArray squares_degree = inner_products.acos().mul(180).div(Math.PI);

                NDArray overlap_scores = null;
                NDArray degree_scores = null;
                NDArray length_scores = null;

                for (int i = 0; i < connect_array.size(0); i++) {
                    NDArray connects = connect_array.get(i);
                    // NOTE: reuses (shadows) the outer `segments` variable from here on.
                    segments = segments_array.get(i);
                    NDArray square = squares.get(i);
                    NDArray degree = squares_degree.get(i);

                    // ###################################### OVERLAP SCORES
                    float cover = 0;
                    float perimeter = 0;
                    // check 0 > 1 > 2 > 3
                    float[] square_length = new float[4];

                    for (int start_idx = 0; start_idx < 4; start_idx++) {
                        int end_idx = (start_idx + 1) % 4;
                        int connect_idx = (int) connects.get(start_idx).toFloatArray()[0];
                        NDArray start_segments = segments.get(start_idx);
                        NDArray end_segments = segments.get(end_idx);

                        // check whether outside or inside
                        int idx_i = (int) start_segments.toFloatArray()[0];
                        int idx_j = (int) start_segments.toFloatArray()[1];
                        NDArray check_dist_mat;
                        if (connect_idx == idx_i) {
                            check_dist_mat = dist_inter_to_segment1;
                        } else {
                            check_dist_mat = dist_inter_to_segment2;
                        }
                        float[] range = check_dist_mat.get(idx_i + "," + idx_j + ",:").toFloatArray();
                        float min_dist = range[0];
                        float max_dist = range[1];
                        float connect_dist = dist_segments.get(connect_idx).toFloatArray()[0];
                        String start_position;
                        float start_min;
                        int start_cover_param;
                        int start_peri_param;
                        if (max_dist > connect_dist) {
                            start_position = "outside";
                            start_min = min_dist;
                            start_cover_param = 0;
                            start_peri_param = 1;
                        } else {
                            start_position = "inside";
                            start_min = min_dist;
                            start_cover_param = -1;
                            start_peri_param = -1;
                        }

                        // check whether outside or inside
                        idx_i = (int) end_segments.toFloatArray()[0];
                        idx_j = (int) end_segments.toFloatArray()[1];
                        if (connect_idx == idx_i) {
                            check_dist_mat = dist_inter_to_segment1;
                        } else {
                            check_dist_mat = dist_inter_to_segment2;
                        }
                        range = check_dist_mat.get(idx_i + "," + idx_j + ",:").toFloatArray();
                        min_dist = range[0];
                        max_dist = range[1];
                        connect_dist = dist_segments.get(connect_idx).toFloatArray()[0];
                        String end_position;
                        float end_min;
                        int end_cover_param;
                        int end_peri_param;
                        if (max_dist > connect_dist) {
                            end_position = "outside";
                            end_min = min_dist;
                            end_cover_param = 0;
                            end_peri_param = 1;
                        } else {
                            end_position = "inside";
                            end_min = min_dist;
                            end_cover_param = -1;
                            end_peri_param = -1;
                        }

                        cover += connect_dist + start_cover_param * start_min + end_cover_param * end_min;
                        perimeter += connect_dist + start_peri_param * start_min + end_peri_param * end_min;

                        square_length[start_idx] = connect_dist + start_peri_param * start_min + end_peri_param * end_min;
                    }
                    if (overlap_scores == null) {
                        overlap_scores = manager.create(cover / perimeter).reshape(1);
                    } else {
                        overlap_scores = overlap_scores.concat(manager.create(cover / perimeter).reshape(1));
                    }

                    // ######################################
                    // ###################################### DEGREE SCORES
                    // Ratio of opposite interior angles; 1.0 means a perfect parallelogram.
                    float[] degreeArr = degree.toFloatArray();
                    float deg0 = degreeArr[0];
                    float deg1 = degreeArr[1];
                    float deg2 = degreeArr[2];
                    float deg3 = degreeArr[3];
                    float deg_ratio1 = deg0 / deg2;
                    if (deg_ratio1 > 1.0) {
                        deg_ratio1 = 1 / deg_ratio1;
                    }
                    float deg_ratio2 = deg1 / deg3;
                    if (deg_ratio2 > 1.0) {
                        deg_ratio2 = 1 / deg_ratio2;
                    }
                    if (degree_scores == null) {
                        degree_scores = manager.create((deg_ratio1 + deg_ratio2) / 2).reshape(1);
                    } else {
                        degree_scores = degree_scores.concat(manager.create((deg_ratio1 + deg_ratio2) / 2).reshape(1));
                    }

                    // ######################################
                    // ###################################### LENGTH SCORES
                    // Ratio of opposite side lengths; 1.0 means opposite sides are equal.
                    float len0 = square_length[0];
                    float len1 = square_length[1];
                    float len2 = square_length[2];
                    float len3 = square_length[3];
                    float len_ratio1 = 0;
                    if (len2 > len0) {
                        len_ratio1 = len0 / len2;
                    } else {
                        len_ratio1 = len2 / len0;
                    }
                    float len_ratio2 = 0;
                    if (len3 > len1) {
                        len_ratio2 = len1 / len3;
                    } else {
                        len_ratio2 = len3 / len1;
                    }
                    if (length_scores == null) {
                        length_scores = manager.create((len_ratio1 + len_ratio2) / 2).reshape(1);
                    } else {
                        length_scores = length_scores.concat(manager.create((len_ratio1 + len_ratio2) / 2).reshape(1));
                    }
                }
                if (overlap_scores != null)
                    overlap_scores = overlap_scores.div(overlap_scores.max().toFloatArray()[0]);

                // ###################################### AREA SCORES
                // Shoelace formula for each quadrilateral's area, normalized by map_size^2.
                NDArray area_scores = squares.reshape(new Shape(-1, 4, 2));
                NDArray area_x = area_scores.get(":, :, 0");
                NDArray area_y = area_scores.get(":, :, 1");
                NDArray correction = area_x.get(":, -1").mul(area_y.get(":, 0")).sub(area_y.get(":, -1").mul(area_x.get(":, 0")));

                NDArray area_scores1 = area_x.get(":, :-1").mul(area_y.get(":, 1:")).sum(new int[]{-1});
                NDArray area_scores2 = area_y.get(":, :-1").mul(area_x.get(":, 1:")).sum(new int[]{-1});

                area_scores = area_scores1.sub(area_scores2);
                area_scores = area_scores.add(correction).abs().mul(0.5);
                area_scores = area_scores.div(map_size * map_size);

                // ###################################### CENTER SCORES
                // Distance of the square's centroid from the (assumed 256x256) map center.
                NDArray centers = manager.create(new float[]{256 / 2, 256 / 2});
                NDArray square_centers = squares.mean(new int[]{1});
                NDArray center2center = centers.sub(square_centers).square().sum().sqrt();
                NDArray center_scores = center2center.div(map_size / Math.sqrt(2.0));

                // Weighted sum of all partial scores; sort candidates best-first.
                if (overlap_scores != null) {
                    score_array = overlap_scores.mul(this.w_overlap).add(degree_scores.mul(this.w_degree)).add(area_scores.mul(this.w_area)).add(center_scores.mul(this.w_center)).add(length_scores.mul(this.w_length));
                    NDArray sorted_idx = score_array.argSort(0, false);
                    score_array = score_array.get(sorted_idx);
                    squares = squares.get(sorted_idx);
                }

                // Scale results back toward original image coordinates; empty arrays throw,
                // which is used here as a (crude) "no result" signal.
                try {
                    new_segments.get(":, 0").muli(2);
                    new_segments.get(":, 1").muli(2);
                    new_segments.get(":, 2").muli(2);
                    new_segments.get(":, 3").muli(2);
                } catch (Exception e) {
                    new_segments = null;
                }

                try {
                    squares.get(":, :, 0").muli(2).divi(input_shape[1]).muli(original_shape[1]);
                    squares.get(":, :, 1").muli(2).divi(input_shape[0]).muli(original_shape[0]);
                    ;
                } catch (Exception e) {
                    squares = null;
                    score_array = null;
                }

                try {
                    inter_points.get(":, 0").muli(2);
                    inter_points.get(":, 1").muli(2);
                } catch (Exception e) {
                    inter_points = null;
                }

                Image img = ImageFactory.getInstance().fromNDArray(imgArray);
                Mat mat = (Mat) img.getWrappedImage();

                // NOTE(review): if the catch above set squares = null, this dereference throws
                // NullPointerException instead of returning null — confirm intended behavior.
                if (squares.getShape().get(0) == 0)
                    return null;
                // Best-scoring square: perspective-warp it to an axis-aligned crop.
                NDArray maxSquare = squares.get(0);
                float[] points = maxSquare.toFloatArray();
                int[] wh = OpenCVUtils.imgCrop(points);

                Mat dst = OpenCVUtils.perspectiveTransform(mat, points);

                img = ImageFactory.getInstance().fromImage(dst);
//                return img;
                return img.getSubImage(0, 0, wh[0], wh[1]);
            }
        }

        /**
         * Scales (h, w) so the shorter side is about {@code resolution}, then rounds both sides
         * to the nearest multiple of 64 (network stride requirement).
         *
         * @return {height, width}
         */
        private int[] resize64(double h, double w, double resolution) {

            double k = resolution / Math.min(h, w);
            h *= k;
            w *= k;

            int height = (int) (Math.round(h / 64.0)) * 64;
            int width = (int) (Math.round(w / 64.0)) * 64;

            return new int[]{height, width};
        }

        @Override
        public Batchifier getBatchifier() {
            return batchifier;
        }

    }
}
@author Calvin + * @date Oct 19, 2021 + */ +public final class RecognitionModel implements AutoCloseable { + private DetectorPool detectorPool; + private HorizontalDetectorPool horizontalDetectorPool; + private RecognizerPool recognizerPool; + + private ZooModel horizontalDetectionModel; + private ZooModel detectionModel; + private ZooModel recognitionModel; + + public void init(String detModel, String recModel, int poolSize) throws MalformedModelException, ModelNotFoundException, IOException { + this.recognitionModel = ModelZoo.loadModel(recognizeCriteria(recModel)); + this.detectionModel = ModelZoo.loadModel(detectCriteria(detModel)); + this.horizontalDetectionModel = ModelZoo.loadModel(horizontalCriteria(detModel)); + + detectorPool = new DetectorPool(poolSize, detectionModel); + horizontalDetectorPool = new HorizontalDetectorPool(poolSize, horizontalDetectionModel); + recognizerPool = new RecognizerPool(poolSize, recognitionModel); + + } + + /** + * 释放资源 + */ + public void close() { + this.recognitionModel.close(); +// this.recognizer.close(); + this.detectionModel.close(); +// this.detector.close(); + this.horizontalDetectionModel.close(); +// this.horizontalDetector.close(); + this.detectorPool.close(); + this.horizontalDetectorPool.close(); + this.recognizerPool.close(); + } + + /** + * 文本检测(支持有倾斜角的文本) + * + * @return + */ + private Criteria detectCriteria(String detUri) { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + .optModelPath(Paths.get(detUri)) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + + /** + * 水平文本检测 + * + * @return + */ + private Criteria horizontalCriteria(String detUri) { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, DetectedObjects.class) + .optModelPath(Paths.get(detUri)) + 
.optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + /** + * 文本识别 + * + * @return + */ + private Criteria recognizeCriteria(String recUri) { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get(recUri)) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + + return criteria; + } + + // 多线程环境,每个线程一个predictor,共享一个model, 资源池(CPU Core 核心数)达到上限则等待 + public String predictSingleLineText(Image image) + throws TranslateException { + Predictor recognizer = recognizerPool.getRecognizer(); + String text = recognizer.predict(image); + // 释放资源 + recognizerPool.releaseRecognizer(recognizer); + return text; + } + + // 多线程环境,每个线程一个predictor,共享一个model, 资源池(CPU Core 核心数)达到上限则等待 + public DetectedObjects predict(Image image) + throws TranslateException { + Predictor horizontalDetector = horizontalDetectorPool.getDetector(); + DetectedObjects detections = horizontalDetector.predict(image); + horizontalDetectorPool.releaseDetector(horizontalDetector); + + List boxes = detections.items(); + List names = new ArrayList<>(); + List prob = new ArrayList<>(); + List rect = new ArrayList<>(); + + Predictor recognizer = recognizerPool.getRecognizer(); + for (int i = 0; i < boxes.size(); i++) { + Image subImg = getSubImage(image, boxes.get(i).getBoundingBox()); + if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { + subImg = rotateImg(subImg); + } + String name = recognizer.predict(subImg); + System.out.println(name); + names.add(name); + prob.add(-1.0); + rect.add(boxes.get(i).getBoundingBox()); + } + // 释放资源 + recognizerPool.releaseRecognizer(recognizer); + + DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect); + return detectedObjects; + } + + // 多线程环境,每个线程一个predictor,共享一个model, 资源池(CPU Core 
核心数)达到上限则等待 + public List predict(NDManager manager, Image image) + throws TranslateException { + + Predictor detector = detectorPool.getDetector(); + NDList boxes = detector.predict(image); + // 释放资源 + detectorPool.releaseDetector(detector); + + // 交给 NDManager自动管理内存 + // attach to manager for automatic memory management + boxes.attach(manager); + + List result = new ArrayList<>(); + Mat mat = (Mat) image.getWrappedImage(); + + Predictor recognizer = recognizerPool.getRecognizer(); + for (int i = 0; i < boxes.size(); i++) { + NDArray box = boxes.get(i); + + float[] pointsArr = box.toFloatArray(); + float[] lt = java.util.Arrays.copyOfRange(pointsArr, 0, 2); + float[] rt = java.util.Arrays.copyOfRange(pointsArr, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(pointsArr, 4, 6); + float[] lb = java.util.Arrays.copyOfRange(pointsArr, 6, 8); + int img_crop_width = (int) Math.max(distance(lt, rt), distance(rb, lb)); + int img_crop_height = (int) Math.max(distance(lt, lb), distance(rt, rb)); + List srcPoints = new ArrayList<>(); + srcPoints.add(new Point((int) lt[0], (int) lt[1])); + srcPoints.add(new Point((int) rt[0], (int) rt[1])); + srcPoints.add(new Point((int) rb[0], (int) rb[1])); + srcPoints.add(new Point((int) lb[0], (int) lb[1])); + List dstPoints = new ArrayList<>(); + dstPoints.add(new Point(0, 0)); + dstPoints.add(new Point(img_crop_width, 0)); + dstPoints.add(new Point(img_crop_width, img_crop_height)); + dstPoints.add(new Point(0, img_crop_height)); + + Mat srcPoint2f = OpenCVUtils.toMat(srcPoints); + Mat dstPoint2f = OpenCVUtils.toMat(dstPoints); + + Mat cvMat = OpenCVUtils.perspectiveTransform(mat, srcPoint2f, dstPoint2f); + + Image subImg = OpenCVImageFactory.getInstance().fromImage(cvMat); +// ImageUtils.saveImage(subImg, i + ".png", "build/output"); + + subImg = subImg.getSubImage(0, 0, img_crop_width, img_crop_height); + if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { + subImg = rotateImg(manager, subImg); + } + + String name = 
recognizer.predict(subImg); + RotatedBox rotatedBox = new RotatedBox(box, name); + result.add(rotatedBox); + + cvMat.release(); + srcPoint2f.release(); + dstPoint2f.release(); + + } + // 释放资源 + recognizerPool.releaseRecognizer(recognizer); + + return result; + } + + private Image getSubImage(Image img, BoundingBox box) { + Rectangle rect = box.getBounds(); + double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight()); + int width = img.getWidth(); + int height = img.getHeight(); + int[] recovered = { + (int) (extended[0] * width), + (int) (extended[1] * height), + (int) (extended[2] * width), + (int) (extended[3] * height) + }; + return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]); + } + + private double[] extendRect(double xmin, double ymin, double width, double height) { + double centerx = xmin + width / 2; + double centery = ymin + height / 2; + if (width > height) { + width += height * 2.0; + height *= 3.0; + } else { + height += width * 2.0; + width *= 3.0; + } + double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2; + double newY = centery - height / 2 < 0 ? 0 : centery - height / 2; + double newWidth = newX + width > 1 ? 1 - newX : width; + double newHeight = newY + height > 1 ? 
1 - newY : height; + return new double[]{newX, newY, newWidth, newHeight}; + } + + private float distance(float[] point1, float[] point2) { + float disX = point1[0] - point2[0]; + float disY = point1[1] - point2[1]; + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + private Image rotateImg(Image image) { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1); + return OpenCVImageFactory.getInstance().fromNDArray(rotated); + } + } + + private Image rotateImg(NDManager manager, Image image) { + NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1); + return OpenCVImageFactory.getInstance().fromNDArray(rotated); + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/model/SingleRecognitionModel.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/model/SingleRecognitionModel.java new file mode 100644 index 00000000..60080903 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/model/SingleRecognitionModel.java @@ -0,0 +1,185 @@ +package me.aias.example.model; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import me.aias.example.utils.common.Point; +import me.aias.example.utils.common.RotatedBox; +import me.aias.example.utils.detection.OCRDetectionTranslator; +import me.aias.example.utils.opencv.OpenCVUtils; +import 
me.aias.example.utils.recognition.PpWordRecognitionTranslator; +import org.opencv.core.Mat; + +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author Calvin + * @date Oct 19, 2021 + */ +public final class SingleRecognitionModel implements AutoCloseable { + private ZooModel detectionModel; + private Predictor detector; + private ZooModel recognitionModel; + private Predictor recognizer; + + public void init(String detModel, String recModel) throws MalformedModelException, ModelNotFoundException, IOException { + this.recognitionModel = ModelZoo.loadModel(recognizeCriteria(recModel)); + this.recognizer = recognitionModel.newPredictor(); + this.detectionModel = ModelZoo.loadModel(detectCriteria(detModel)); + this.detector = detectionModel.newPredictor(); + } + + public void close() { + this.recognitionModel.close(); + this.recognizer.close(); + this.detectionModel.close(); + this.detector.close(); + } + + /** + * 文本检测 + * + * @return + */ + private Criteria detectCriteria(String detUri) { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + .optModelPath(Paths.get(detUri)) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + +// private Criteria detectCriteria(String detUri) { +// Criteria criteria = +// Criteria.builder() +// .optEngine("OnnxRuntime") +// .optModelName("inference") +// .setTypes(Image.class, DetectedObjects.class) +// .optModelPath(Paths.get(detUri)) +// .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap())) +// .optProgress(new ProgressBar()) +// .build(); +// +// return criteria; +// } + + /** + * 文本识别 + * + * @return + */ + private Criteria recognizeCriteria(String recUri) { + Criteria criteria = + 
Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get(recUri)) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + + return criteria; + } + + // 多线程环境,需要把 Predictor detector 改写成线程池,每个线程一个predictor,共享一个model + public synchronized String predictSingleLineText(Image image) + throws TranslateException { + return recognizer.predict(image); + } + + // 多线程环境,需要把 Predictor detector / Predictor recognizer 改写成线程池,每个线程一个predictor,共享一个model + public synchronized List predict(NDManager manager, Image image) + throws TranslateException { + NDList boxes = detector.predict(image); + // 交给 NDManager自动管理内存 + // attach to manager for automatic memory management + boxes.attach(manager); + + List result = new ArrayList<>(); + Mat mat = (Mat) image.getWrappedImage(); + + for (int i = 0; i < boxes.size(); i++) { + NDArray box = boxes.get(i); + + float[] pointsArr = box.toFloatArray(); + float[] lt = java.util.Arrays.copyOfRange(pointsArr, 0, 2); + float[] rt = java.util.Arrays.copyOfRange(pointsArr, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(pointsArr, 4, 6); + float[] lb = java.util.Arrays.copyOfRange(pointsArr, 6, 8); + int img_crop_width = (int) Math.max(distance(lt, rt), distance(rb, lb)); + int img_crop_height = (int) Math.max(distance(lt, lb), distance(rt, rb)); + List srcPoints = new ArrayList<>(); + srcPoints.add(new Point((int) lt[0], (int) lt[1])); + srcPoints.add(new Point((int) rt[0], (int) rt[1])); + srcPoints.add(new Point((int) rb[0], (int) rb[1])); + srcPoints.add(new Point((int) lb[0], (int) lb[1])); + List dstPoints = new ArrayList<>(); + dstPoints.add(new Point(0, 0)); + dstPoints.add(new Point(img_crop_width, 0)); + dstPoints.add(new Point(img_crop_width, img_crop_height)); + dstPoints.add(new Point(0, img_crop_height)); + + Mat srcPoint2f = OpenCVUtils.toMat(srcPoints); + Mat dstPoint2f = 
OpenCVUtils.toMat(dstPoints); + + Mat cvMat = OpenCVUtils.perspectiveTransform(mat, srcPoint2f, dstPoint2f); + + Image subImg = OpenCVImageFactory.getInstance().fromImage(cvMat); +// ImageUtils.saveImage(subImg, i + ".png", "build/output"); + + subImg = subImg.getSubImage(0, 0, img_crop_width, img_crop_height); + if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { + subImg = rotateImg(manager, subImg); + } + + String name = recognizer.predict(subImg); + RotatedBox rotatedBox = new RotatedBox(box, name); + result.add(rotatedBox); + + cvMat.release(); + srcPoint2f.release(); + dstPoint2f.release(); + + } + return result; + } + + private BufferedImage get_rotate_crop_image(Image image, NDArray box) { + return null; + } + + private float distance(float[] point1, float[] point2) { + float disX = point1[0] - point2[0]; + float disY = point1[1] - point2[1]; + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + private Image rotateImg(NDManager manager, Image image) { + NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1); + return OpenCVImageFactory.getInstance().fromNDArray(rotated); + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/OcrDirectionDetection.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/OcrDirectionDetection.java new file mode 100755 index 00000000..dccab1bf --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/OcrDirectionDetection.java @@ -0,0 +1,138 @@ +package me.aias.example.utils.cls; + +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.Criteria; 
+import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import me.aias.example.utils.common.DirectionInfo; +import me.aias.example.utils.detection.PpWordDetectionTranslator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +public final class OcrDirectionDetection { + + private static final Logger logger = LoggerFactory.getLogger(OcrDirectionDetection.class); + + public OcrDirectionDetection() { + } + + public DetectedObjects predict( + Image image, + Predictor detector, + Predictor rotateClassifier) + throws TranslateException { + DetectedObjects detections = detector.predict(image); + + List boxes = detections.items(); + + List names = new ArrayList<>(); + List prob = new ArrayList<>(); + List rect = new ArrayList<>(); + + for (int i = 0; i < boxes.size(); i++) { + Image subImg = getSubImage(image, boxes.get(i).getBoundingBox()); + DirectionInfo result = null; + if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { + subImg = rotateImg(subImg); + result = rotateClassifier.predict(subImg); + prob.add(result.getProb()); + if (result.getName().equalsIgnoreCase("Rotate")) { + names.add("90"); + } else { + names.add("270"); + } + } else { + result = rotateClassifier.predict(subImg); + prob.add(result.getProb()); + if (result.getName().equalsIgnoreCase("No Rotate")) { + names.add("0"); + } else { + names.add("180"); + } + } + rect.add(boxes.get(i).getBoundingBox()); + } + DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect); + + return detectedObjects; + } + + public Criteria detectCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, DetectedObjects.class) + .optModelPath(Paths.get("models/ch_PP-OCRv2_det_infer_onnx.zip")) + .optTranslator(new PpWordDetectionTranslator(new 
ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + public Criteria clsCriteria() { + + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, DirectionInfo.class) + .optModelPath(Paths.get("models/ch_ppocr_mobile_v2.0_cls_onnx.zip")) + .optTranslator(new PpWordRotateTranslator()) + .optProgress(new ProgressBar()) + .build(); + return criteria; + } + + private Image getSubImage(Image img, BoundingBox box) { + Rectangle rect = box.getBounds(); + double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight()); + int width = img.getWidth(); + int height = img.getHeight(); + int[] recovered = { + (int) (extended[0] * width), + (int) (extended[1] * height), + (int) (extended[2] * width), + (int) (extended[3] * height) + }; + return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]); + } + + private double[] extendRect(double xmin, double ymin, double width, double height) { + double centerx = xmin + width / 2; + double centery = ymin + height / 2; + if (width > height) { + width += height * 2.0; + height *= 3.0; + } else { + height += width * 2.0; + width *= 3.0; + } + double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2; + double newY = centery - height / 2 < 0 ? 0 : centery - height / 2; + double newWidth = newX + width > 1 ? 1 - newX : width; + double newHeight = newY + height > 1 ? 
1 - newY : height; + return new double[]{newX, newY, newWidth, newHeight}; + } + + private Image rotateImg(Image image) { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1); + return ImageFactory.getInstance().fromNDArray(rotated); + } + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/PpWordRotateTranslator.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/PpWordRotateTranslator.java new file mode 100755 index 00000000..20155910 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/cls/PpWordRotateTranslator.java @@ -0,0 +1,76 @@ +package me.aias.example.utils.cls; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import me.aias.example.utils.common.DirectionInfo; + +import java.util.Arrays; +import java.util.List; + +public class PpWordRotateTranslator implements Translator { + List classes = Arrays.asList("No Rotate", "Rotate"); + + public PpWordRotateTranslator() { + } + + public DirectionInfo processOutput(TranslatorContext ctx, NDList list) { + NDArray prob = list.singletonOrThrow(); + float[] res = prob.toFloatArray(); + int maxIndex = 0; + if (res[1] > res[0]) { + maxIndex = 1; + } + + return new DirectionInfo(classes.get(maxIndex), Double.valueOf(res[maxIndex])); + } + +// public NDList processInput2(TranslatorContext ctx, Image input){ +// NDArray img = input.toNDArray(ctx.getNDManager()); +// img = NDImageUtils.resize(img, 192, 48); +// img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F); +// img = img.expandDims(0); +// return new NDList(new NDArray[]{img}); 
+// } + + public NDList processInput(TranslatorContext ctx, Image input) { + NDArray img = input.toNDArray(ctx.getNDManager()); + int imgC = 3; + int imgH = 48; + int imgW = 192; + + NDArray array = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW)); + + int h = input.getHeight(); + int w = input.getWidth(); + int resized_w = 0; + + float ratio = (float) w / (float) h; + if (Math.ceil(imgH * ratio) > imgW) { + resized_w = imgW; + } else { + resized_w = (int) (Math.ceil(imgH * ratio)); + } + + img = NDImageUtils.resize(img, resized_w, imgH); + + img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F); + // img = img.transpose(2, 0, 1); + + array.set(new NDIndex(":,:,0:" + resized_w), img); + + array = array.expandDims(0); + + return new NDList(new NDArray[]{array}); + } + + public Batchifier getBatchifier() { + return null; + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DJLImageUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DJLImageUtils.java new file mode 100644 index 00000000..5b820394 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DJLImageUtils.java @@ -0,0 +1,82 @@ +package me.aias.example.utils.common; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; + +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * + * @author Calvin + * + * @email 179209347@qq.com + **/ +public class DJLImageUtils { + + public static void saveDJLImage(Image img, String name, String path) { + Path outputDir = Paths.get(path); + Path imagePath = outputDir.resolve(name); + try { + img.save(Files.newOutputStream(imagePath), "png"); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public static void saveBoundingBoxImage( + Image img, 
DetectedObjects detection, String name, String path) throws IOException { + // Make imageName copy with alpha channel because original imageName was jpg + img.drawBoundingBoxes(detection); + Path outputDir = Paths.get(path); + Files.createDirectories(outputDir); + Path imagePath = outputDir.resolve(name); + // OpenJDK can't save jpg with alpha channel + img.save(Files.newOutputStream(imagePath), "png"); + } + + public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) { + Graphics2D g = (Graphics2D) image.getGraphics(); + try { + g.setColor(new Color(246, 96, 0)); + BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER); + g.setStroke(bStroke); + g.drawRect(x, y, width, height); + + } finally { + g.dispose(); + } + } + + public static void drawImageRect( + BufferedImage image, int x, int y, int width, int height, Color c) { + Graphics2D g = (Graphics2D) image.getGraphics(); + try { + g.setColor(c); + BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER); + g.setStroke(bStroke); + g.drawRect(x, y, width, height); + + } finally { + g.dispose(); + } + } + + public static void drawImageText(BufferedImage image, String text) { + Graphics graphics = image.getGraphics(); + int fontSize = 100; + Font font = new Font("楷体", Font.PLAIN, fontSize); + try { + graphics.setFont(font); + graphics.setColor(new Color(246, 96, 0)); + int strWidth = graphics.getFontMetrics().stringWidth(text); + graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30); + } finally { + graphics.dispose(); + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DirectionInfo.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DirectionInfo.java new file mode 100755 index 00000000..1c369277 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DirectionInfo.java @@ -0,0 +1,27 @@ 
+package me.aias.example.utils.common; + +public class DirectionInfo { + private String name; + private Double prob; + + public DirectionInfo(String name, Double prob) { + this.name = name; + this.prob = prob; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public Double getProb() { + return prob; + } + + public void setProb(Double prob) { + this.prob = prob; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DistanceUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DistanceUtils.java new file mode 100644 index 00000000..f3d19d70 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/DistanceUtils.java @@ -0,0 +1,110 @@ +package me.aias.example.utils.common; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Utility class for calculating distances and IoU. 
+ * + * @author Calvin + * @date Oct 19, 2021 + */ +public class DistanceUtils { + /** + * Calculate L2 distance + * + * @param contentLabels 内容识别区 - the list of labels for content recognition area + * @param detectedTexts 文本检测区 - the list of labels for text detection area + * @return + */ + public static Map l2Distance(List contentLabels, List detectedTexts) { + Map hashMap = new ConcurrentHashMap<>(); + for (int i = 0; i < contentLabels.size(); i++) { + String field = contentLabels.get(i).getField(); + double minDistance = Double.MAX_VALUE; + String value = ""; + for (int j = 0; j < detectedTexts.size(); j++) { + double dis = l2Distance(contentLabels.get(i).getCenterPoint(), detectedTexts.get(j).getCenterPoint()); + if (dis < minDistance) { + minDistance = dis; + value = detectedTexts.get(j).getValue(); + } + } + System.out.println(field + " : " + value); + hashMap.put(field, value); + } + return hashMap; + } + + /** + * Calculate iou + * + * @param contentLabels 内容识别区 - the list of labels for content recognition area + * @param detectedTexts 文本检测区 - the list of labels for text detection area + * @return + */ + public static Map iou(List contentLabels, List detectedTexts) { + Map hashMap = new ConcurrentHashMap<>(); + for (int i = 0; i < contentLabels.size(); i++) { + String field = contentLabels.get(i).getField(); + double maxIOU = 0d; + String value = ""; + int[] box_1 = PointUtils.rectXYXY(contentLabels.get(i).getPoints()); + for (int j = 0; j < detectedTexts.size(); j++) { + int[] box_2 = PointUtils.rectXYXY(detectedTexts.get(j).getPoints()); + double iou = compute_iou(box_1, box_2); + if (iou > maxIOU) { + maxIOU = iou; + value = detectedTexts.get(j).getValue(); + } + } +// System.out.println(field + " : " + value); + hashMap.put(field, value); + } + return hashMap; + } + + /** + * Calculate L2 distance + * + * @param point1 + * @param point2 + * @return + */ + public static double l2Distance(ai.djl.modality.cv.output.Point point1, 
ai.djl.modality.cv.output.Point point2) { + double partX = Math.pow((point1.getX() - point2.getX()), 2); + double partY = Math.pow((point1.getY() - point2.getY()), 2); + return Math.sqrt(partX + partY); + } + + /** + * computing IoU + * + * @param rec1: (y0, x0, y1, x1), which reflects (top, left, bottom, right) + * @param rec2: (y0, x0, y1, x1) + * @return scala value of IoU + */ + public static float compute_iou(int[] rec1, int[] rec2) { + // computing area of each rectangles + int S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]); + int S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]); + + // computing the sum_area + int sum_area = S_rec1 + S_rec2; + + // find the each edge of intersect rectangle + int left_line = Math.max(rec1[1], rec2[1]); + int right_line = Math.min(rec1[3], rec2[3]); + int top_line = Math.max(rec1[0], rec2[0]); + int bottom_line = Math.min(rec1[2], rec2[2]); + + // judge if there is an intersect + if (left_line >= right_line || top_line >= bottom_line) { + return 0.0f; + } else { + float intersect = (right_line - left_line) * (bottom_line - top_line); + return (intersect / (sum_area - intersect)) * 1.0f; + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ImageUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ImageUtils.java new file mode 100755 index 00000000..bb411f3f --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ImageUtils.java @@ -0,0 +1,221 @@ +package me.aias.example.utils.common; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.ndarray.NDArray; +import org.opencv.core.Mat; +import org.opencv.core.Point; +import org.opencv.core.Scalar; +import org.opencv.imgproc.Imgproc; + +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.nio.file.Files; +import 
java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +public class ImageUtils { + + /** + * 保存BufferedImage图片 + * + * @author Calvin + */ + public static void saveImage(BufferedImage img, String name, String path) { + Image djlImg = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式,自动适配 + Path outputDir = Paths.get(path); + Path imagePath = outputDir.resolve(name); + // OpenJDK 不能保存 jpg 图片的 alpha channel + try { + djlImg.save(Files.newOutputStream(imagePath), "png"); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * 保存DJL图片 + * + * @author Calvin + */ + public static void saveImage(Image img, String name, String path) { + Path outputDir = Paths.get(path); + Path imagePath = outputDir.resolve(name); + // OpenJDK 不能保存 jpg 图片的 alpha channel + try { + img.save(Files.newOutputStream(imagePath), "png"); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * 保存图片,含检测框 + * + * @author Calvin + */ + public static void saveBoundingBoxImage( + Image img, DetectedObjects detection, String name, String path) throws IOException { + // Make image copy with alpha channel because original image was jpg + img.drawBoundingBoxes(detection); + Path outputDir = Paths.get(path); + Files.createDirectories(outputDir); + Path imagePath = outputDir.resolve(name); + // OpenJDK can't save jpg with alpha channel + img.save(Files.newOutputStream(imagePath), "png"); + } + + /** + * 画矩形 + * + * @param mat + * @param box + * @return + */ + public static void drawRect(Mat mat, NDArray box) { + + float[] points = box.toFloatArray(); + List list = new ArrayList<>(); + + for (int i = 0; i < 4; i++) { + Point point = new Point((int) points[2 * i], (int) points[2 * i + 1]); + list.add(point); + } + + Imgproc.line(mat, list.get(0), list.get(1), new Scalar(0, 255, 0), 1); + Imgproc.line(mat, list.get(1), list.get(2), new Scalar(0, 255, 0), 1); + Imgproc.line(mat, list.get(2), list.get(3), new Scalar(0, 255, 0), 1); 
+ Imgproc.line(mat, list.get(3), list.get(0), new Scalar(0, 255, 0), 1); + } + + /** + * 画矩形 + * + * @param mat + * @param box + * @return + */ + public static void drawRectWithText(Mat mat, NDArray box, String text) { + + float[] points = box.toFloatArray(); + List list = new ArrayList<>(); + + for (int i = 0; i < 4; i++) { + Point point = new Point((int) points[2 * i], (int) points[2 * i + 1]); + list.add(point); + } + + Imgproc.line(mat, list.get(0), list.get(1), new Scalar(0, 255, 0), 1); + Imgproc.line(mat, list.get(1), list.get(2), new Scalar(0, 255, 0), 1); + Imgproc.line(mat, list.get(2), list.get(3), new Scalar(0, 255, 0), 1); + Imgproc.line(mat, list.get(3), list.get(0), new Scalar(0, 255, 0), 1); + // 中文乱码 + Imgproc.putText(mat, text, list.get(0), Imgproc.FONT_HERSHEY_SCRIPT_SIMPLEX, 1.0, new Scalar(0, 255, 0), 1); + } + + /** + * 画检测框(有倾斜角) + * + * @author Calvin + */ + public static void drawImageRect(BufferedImage image, NDArray box) { + float[] points = box.toFloatArray(); + int[] xPoints = new int[5]; + int[] yPoints = new int[5]; + + for (int i = 0; i < 4; i++) { + xPoints[i] = (int) points[2 * i]; + yPoints[i] = (int) points[2 * i + 1]; + } + xPoints[4] = xPoints[0]; + yPoints[4] = yPoints[0]; + + // 将绘制图像转换为Graphics2D + Graphics2D g = (Graphics2D) image.getGraphics(); + try { + g.setColor(new Color(0, 255, 0)); + // 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角 + BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER); + g.setStroke(bStroke); + g.drawPolyline(xPoints, yPoints, 5); // xPoints, yPoints, nPoints + } finally { + g.dispose(); + } + } + + /** + * 画检测框(有倾斜角)和文本 + * + * @author Calvin + */ + public static void drawImageRectWithText(BufferedImage image, NDArray box, String text) { + float[] points = box.toFloatArray(); + int[] xPoints = new int[5]; + int[] yPoints = new int[5]; + + for (int i = 0; i < 4; i++) { + xPoints[i] = (int) points[2 * i]; + yPoints[i] = (int) points[2 * i + 1]; + } + xPoints[4] = xPoints[0]; + 
yPoints[4] = yPoints[0]; + + // 将绘制图像转换为Graphics2D + Graphics2D g = (Graphics2D) image.getGraphics(); + try { + int fontSize = 32; + Font font = new Font("楷体", Font.PLAIN, fontSize); + g.setFont(font); + g.setColor(new Color(0, 0, 255)); + // 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角 + BasicStroke bStroke = new BasicStroke(2, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER); + g.setStroke(bStroke); + g.drawPolyline(xPoints, yPoints, 5); // xPoints, yPoints, nPoints + g.drawString(text, xPoints[0], yPoints[0]); + } finally { + g.dispose(); + } + } + + /** + * 画检测框 + * + * @author Calvin + */ + public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) { + // 将绘制图像转换为Graphics2D + Graphics2D g = (Graphics2D) image.getGraphics(); + try { + g.setColor(new Color(0, 255, 0)); + // 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角 + BasicStroke bStroke = new BasicStroke(2, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER); + g.setStroke(bStroke); + g.drawRect(x, y, width, height); + } finally { + g.dispose(); + } + } + + /** + * 显示文字 + * + * @author Calvin + */ + public static void drawImageText(BufferedImage image, String text, int x, int y) { + Graphics graphics = image.getGraphics(); + int fontSize = 32; + Font font = new Font("楷体", Font.PLAIN, fontSize); + try { + graphics.setFont(font); + graphics.setColor(new Color(0, 0, 255)); + int strWidth = graphics.getFontMetrics().stringWidth(text); + graphics.drawString(text, x, y); + } finally { + graphics.dispose(); + } + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/LabelBean.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/LabelBean.java new file mode 100644 index 00000000..76fe43ee --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/LabelBean.java @@ -0,0 +1,16 @@ +package me.aias.example.utils.common; + +import lombok.Data; + +import java.util.List; + +@Data +public class 
LabelBean { + private int index; + private int active; + private String type; + private String value; + private String field; + private List points; + private ai.djl.modality.cv.output.Point centerPoint; +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/Point.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/Point.java new file mode 100644 index 00000000..f7860bf3 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/Point.java @@ -0,0 +1,17 @@ +package me.aias.example.utils.common; + +import lombok.Data; + +@Data +public class Point { + private int x; + private int y; + + public Point() { + } + + public Point(int x, int y) { + this.x = x; + this.y = y; + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/PointUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/PointUtils.java new file mode 100644 index 00000000..f9cddbf9 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/PointUtils.java @@ -0,0 +1,321 @@ +package me.aias.example.utils.common; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; + +import java.util.List; + +/** + * @author Calvin + * @date Oct 19, 2021 + */ +public class PointUtils { + /** + * 计算两点距离 + * @param point1 + * @param point2 + * @return + */ + public static float distance(float[] point1, float[] point2) { + float disX = point1[0] - point2[0]; + float disY = point1[1] - point2[1]; + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + /** + * 计算两点距离 + * @param point1 + * @param point2 + * @return + */ + public static float distance(Point point1, Point point2) { + double disX = point1.getX() - point2.getX(); + double disY = 
point1.getY() - point2.getY(); + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + /** + * sort the points based on their x-coordinates + * 顺时针 + * + * @param pts + * @return + */ + + private static NDArray order_points_clockwise(NDArray pts) { + NDList list = new NDList(); + long[] indexes = pts.get(":, 0").argSort().toLongArray(); + + // grab the left-most and right-most points from the sorted + // x-roodinate points + Shape s1 = pts.getShape(); + NDArray leftMost1 = pts.get(indexes[0] + ",:"); + NDArray leftMost2 = pts.get(indexes[1] + ",:"); + NDArray leftMost = leftMost1.concat(leftMost2).reshape(2, 2); + NDArray rightMost1 = pts.get(indexes[2] + ",:"); + NDArray rightMost2 = pts.get(indexes[3] + ",:"); + NDArray rightMost = rightMost1.concat(rightMost2).reshape(2, 2); + + // now, sort the left-most coordinates according to their + // y-coordinates so we can grab the top-left and bottom-left + // points, respectively + indexes = leftMost.get(":, 1").argSort().toLongArray(); + NDArray lt = leftMost.get(indexes[0] + ",:"); + NDArray lb = leftMost.get(indexes[1] + ",:"); + indexes = rightMost.get(":, 1").argSort().toLongArray(); + NDArray rt = rightMost.get(indexes[0] + ",:"); + NDArray rb = rightMost.get(indexes[1] + ",:"); + + list.add(lt); + list.add(rt); + list.add(rb); + list.add(lb); + + NDArray rect = NDArrays.concat(list).reshape(4, 2); + return rect; + } + + /** + * 计算四边形的面积 + * 根据海伦公式(Heron's formula)计算面积 + * + * @param arr + * @return + */ + public static double getQuadArea(NDManager manager, double[][] arr) { + NDArray ndArray = manager.create(arr).reshape(4, 2); + ndArray = order_points_clockwise(ndArray); + double[] array = ndArray.toDoubleArray(); + + double x1 = array[0]; + double y1 = array[1]; + double x2 = array[2]; + double y2 = array[3]; + double x3 = array[4]; + double y3 = array[5]; + double x4 = array[6]; + double y4 = array[7]; + + double totalArea; + if (isInTriangle(x2, y2, x3, y3, x4, y4, x1, y1)) { // 判断点 
(x1, y1) 是否在三角形 (x2,y2),(x3,y3),(x4,y4) 内 + double area1 = getTriangleArea(x2, y2, x3, y3, x1, y1); + double area2 = getTriangleArea(x2, y2, x4, y4, x1, y1); + double area3 = getTriangleArea(x3, y3, x4, y4, x1, y1); + totalArea = area1 + area2 + area3; + } else if (isInTriangle(x1, y1, x3, y3, x4, y4, x2, y2)) {// 判断点 (x2, y2) 是否在三角形 (x1,y1),(x3,y3),(x4,y4) 内 + double area1 = getTriangleArea(x1, y1, x3, y3, x2, y2); + double area2 = getTriangleArea(x1, y1, x4, y4, x2, y2); + double area3 = getTriangleArea(x3, y3, x4, y4, x2, y2); + totalArea = area1 + area2 + area3; + } else if (isInTriangle(x1, y1, x2, y2, x4, y4, x3, y3)) {// 判断点 (x3, y3) 是否在三角形 (x1,y1),(x2,y2),(x4,y4) 内 + double area1 = getTriangleArea(x1, y1, x2, y2, x3, y3); + double area2 = getTriangleArea(x1, y1, x4, y4, x3, y3); + double area3 = getTriangleArea(x2, y2, x4, y4, x3, y3); + totalArea = area1 + area2 + area3; + } else if (isInTriangle(x1, y1, x2, y2, x3, y3, x4, y4)) {// 判断点 (x4, y4) 是否在三角形 (x1,y1),(x2,y2),(x3,y3) 内 + double area1 = getTriangleArea(x1, y1, x2, y2, x4, y4); + double area2 = getTriangleArea(x1, y1, x3, y3, x4, y4); + double area3 = getTriangleArea(x2, y2, x3, y3, x4, y4); + totalArea = area1 + area2 + area3; + } else { + double area1 = getTriangleArea(x1, y1, x2, y2, x3, y3); + double area2 = getTriangleArea(x1, y1, x3, y3, x4, y4); + totalArea = area1 + area2; + } + + return totalArea; + } + + /** + * 判断点 (px, py) 是否在三角形 (x1,y1),(x2,y2),(x3,y3) 内 + * + * @param x1 + * @param y1 + * @param x2 + * @param y2 + * @param x3 + * @param y3 + * @param px + * @param py + * @return + */ + public static boolean isInTriangle(double x1, double y1, double x2, double y2, double x3, double y3, double px, double py) { + if(!isTriangle(x1, y1, x2, y2, px, py)) + return false; + double area1 = getTriangleArea(x1, y1, x2, y2, px, py); + if(!isTriangle(x1, y1, x3, y3, px, py)) + return false; + double area2 = getTriangleArea(x1, y1, x3, y3, px, py); + if(!isTriangle(x2, y2, x3, y3, px, py)) + return 
false; + double area3 = getTriangleArea(x2, y2, x3, y3, px, py); + if(!isTriangle(x1, y1, x2, y2, x3, y3)) + return false; + double totalArea = getTriangleArea(x1, y1, x2, y2, x3, y3); + double delta = Math.abs(totalArea - (area1 + area2 + area3)); + if (delta < 1) + return true; + else + return false; + } + + /** + * 给定3个点坐标(x1,y1),(x2,y2),(x3,y3),给出判断是否能组成三角形 + * @param x1 + * @param y1 + * @param x2 + * @param y2 + * @param x3 + * @param y3 + * @return + */ + public static boolean isTriangle(double x1, double y1, double x2, double y2, double x3, double y3) { + double a = Math.sqrt(Math.pow(x1-x2, 2) + Math.pow(y1-y2, 2)); + double b = Math.sqrt(Math.pow(x1-x3, 2) + Math.pow(y1-y3, 2)); + double c = Math.sqrt(Math.pow(x2-x3, 2) + Math.pow(y2-y3, 2)); + return a + b > c && b + c > a && a + c > b; + } + + /** + * 计算三角形的面积 + * 根据海伦公式(Heron's formula)计算三角形面积 + * + * @param x1 + * @param y1 + * @param x2 + * @param y2 + * @param x3 + * @param y3 + * @return + */ + public static double getTriangleArea(double x1, double y1, double x2, double y2, double x3, double y3) { + double a = Math.sqrt(Math.pow(x2 - x1, 2) + Math.pow(y2 - y1, 2)); + double b = Math.sqrt(Math.pow(x3 - x2, 2) + Math.pow(y3 - y2, 2)); + double c = Math.sqrt(Math.pow(x1 - x3, 2) + Math.pow(y1 - y3, 2)); + double p = (a + b + c) / 2; + double area = Math.sqrt(p * (p - a) * (p - b) * (p - c)); + return area; + } + + public static ai.djl.modality.cv.output.Point getCenterPoint(List points) { + double sumX = 0; + double sumY = 0; + + for (Point point : points) { + sumX = sumX + point.getX(); + sumY = sumY + point.getY(); + } + + ai.djl.modality.cv.output.Point centerPoint = new ai.djl.modality.cv.output.Point(sumX / 4, sumY / 4); + return centerPoint; + } + + public static Point transformPoint(NDManager manager, org.opencv.core.Mat mat, Point point) { + double[][] pointsArray = new double[3][3]; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + pointsArray[i][j] = mat.get(i, j)[0]; + } + } 
+ NDArray ndPoints = manager.create(pointsArray); + + double[] vector = new double[3]; + vector[0] = point.getX(); + vector[1] = point.getY(); + vector[2] = 1f; + NDArray vPoints = manager.create(vector); + vPoints = vPoints.reshape(3, 1); + NDArray result = ndPoints.matMul(vPoints); + double[] dArray = result.toDoubleArray(); + if (dArray[2] != 0) { + point.setX((int) (dArray[0] / dArray[2])); + point.setY((int) (dArray[1] / dArray[2])); + } + + return point; + } + + public static List transformPoints(NDManager manager, org.opencv.core.Mat mat, List points) { + int cols = mat.cols(); + int rows = mat.rows(); + double[][] pointsArray = new double[rows][cols]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + pointsArray[i][j] = mat.get(i, j)[0]; + } + } + NDArray ndPoints = manager.create(pointsArray); + + double[] vector = new double[3]; + for (int i = 0; i < points.size(); i++) { + vector[0] = points.get(i).getX(); + vector[1] = points.get(i).getY(); + vector[2] = 1f; + NDArray vPoints = manager.create(vector); + vPoints = vPoints.reshape(3, 1); + NDArray result = ndPoints.matMul(vPoints); + double[] dArray = result.toDoubleArray(); + if (dArray.length > 2) { + if (dArray[2] != 0) { + points.get(i).setX((int) (dArray[0] / dArray[2])); + points.get(i).setY((int) (dArray[1] / dArray[2])); + } + } else { + points.get(i).setX((int) (dArray[0])); + points.get(i).setY((int) (dArray[1])); + } + + } + + return points; + } + + /** + * Get (x1,y1,x2,y2) coordinations + * + * @param points + * @return + */ + public static int[] rectXYXY(List points) { + int left = points.get(0).getX(); + int top = points.get(0).getY(); + int right = points.get(2).getX(); + int bottom = points.get(2).getY(); + return new int[]{left, top, right, bottom}; + } + + /** + * Get (x1,y1,w,h) coordinations + * + * @param points + * @return + */ + public static int[] rectXYWH(List points) { + int minX = Integer.MAX_VALUE; + int minY = Integer.MAX_VALUE; + int maxX = 
Integer.MIN_VALUE; + int maxY = Integer.MIN_VALUE; + + for (Point point : points) { + int x = point.getX(); + int y = point.getY(); + if (x < minX) + minX = x; + if (x > maxX) + maxX = x; + if (y < minY) + minY = y; + if (y > maxY) + maxY = y; + } + + int w = maxX - minX; + int h = maxY - minY; + return new int[]{minX, minY, w, h}; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjItemBean.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjItemBean.java new file mode 100644 index 00000000..d13715bb --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjItemBean.java @@ -0,0 +1,12 @@ +package me.aias.example.utils.common; + +import ai.djl.modality.cv.Image; +import lombok.Data; + +import java.util.List; + +@Data +public class ProjItemBean { + private Image image; + private org.opencv.core.Mat warpMat; +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjUtils.java new file mode 100644 index 00000000..eff75ace --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/ProjUtils.java @@ -0,0 +1,185 @@ +package me.aias.example.utils.common; + +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.translate.TranslateException; +import ai.djl.util.Pair; +import me.aias.example.model.SingleRecognitionModel; +import me.aias.example.utils.opencv.OpenCVUtils; +import org.opencv.core.Mat; +import org.opencv.imgproc.Imgproc; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author Calvin + * @date Jun 18, 2023 + */ +public class ProjUtils { + + /** + * 获取图片对应2个4变形4对顶点 + * + * @return + */ + + public static Pair, List> projPointsPair(NDManager manager, SingleRecognitionModel 
recognitionModel, Image templateImg, Image targetImg) throws TranslateException { + // 模版文本检测 1 + // Text detection area + List templateTexts = new ArrayList<>(); + List templateTextsDet = recognitionModel.predict(manager, templateImg); + for (RotatedBox rotatedBox : templateTextsDet) { + LabelBean labelBean = new LabelBean(); + List points = new ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + labelBean.setField(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + templateTexts.add(labelBean); + } + +// // 转 BufferedImage 解决 Imgproc.putText 中文乱码问题 +// Mat wrappedImage = (Mat) templateImg.getWrappedImage(); +// BufferedImage bufferedImage = OpenCVUtils.mat2Image(wrappedImage); +// for (RotatedBox result : templateTextsDet) { +// ImageUtils.drawImageRectWithText(bufferedImage, result.getBox(), result.getText()); +// } +// +// Mat image2Mat = OpenCVUtils.image2Mat(bufferedImage); +// templateImg = OpenCVImageFactory.getInstance().fromImage(image2Mat); +// ImageUtils.saveImage(templateImg, "ocr_result.png", "build/output"); + + // 目标文本检测 2 + // Text detection area + List targetTexts = new ArrayList<>(); + List textDetections = recognitionModel.predict(manager, targetImg); + for (RotatedBox rotatedBox : textDetections) { + LabelBean labelBean = new LabelBean(); + List points = new ArrayList<>(); + labelBean.setValue(rotatedBox.getText()); + + float[] pointsArr = rotatedBox.getBox().toFloatArray(); + for (int i = 0; i < 4; i++) { + Point point = new Point((int) pointsArr[2 * i], (int) pointsArr[2 * i + 1]); + points.add(point); + } + + labelBean.setPoints(points); + labelBean.setCenterPoint(PointUtils.getCenterPoint(points)); + targetTexts.add(labelBean); + } + + List srcPoints = new ArrayList<>(); 
+ List dstPoints = new ArrayList<>(); + for (int i = 0; i < templateTexts.size(); i++) { + String anchorText = templateTexts.get(i).getValue(); + for (int j = 0; j < targetTexts.size(); j++) { + String detectedText = targetTexts.get(j).getValue(); + if (detectedText.equals(anchorText)) { + dstPoints.add(templateTexts.get(i)); + srcPoints.add(targetTexts.get(j)); + } + } + } + + List srcPointsList = new ArrayList<>(); + List dstPointsList = new ArrayList<>(); + + for (int i = 0; i < srcPoints.size(); i++) { + for (int j = i + 1; j < srcPoints.size(); j++) { + for (int k = j + 1; k < srcPoints.size(); k++) { + for (int l = k + 1; l < srcPoints.size(); l++) { + double[][] srcArr = new double[4][2]; + srcArr[0][0] = srcPoints.get(i).getCenterPoint().getX(); + srcArr[0][1] = srcPoints.get(i).getCenterPoint().getY(); + srcArr[1][0] = srcPoints.get(j).getCenterPoint().getX(); + srcArr[1][1] = srcPoints.get(j).getCenterPoint().getY(); + srcArr[2][0] = srcPoints.get(k).getCenterPoint().getX(); + srcArr[2][1] = srcPoints.get(k).getCenterPoint().getY(); + srcArr[3][0] = srcPoints.get(l).getCenterPoint().getX(); + srcArr[3][1] = srcPoints.get(l).getCenterPoint().getY(); + srcPointsList.add(srcArr); + + double[][] dstArr = new double[4][2]; + dstArr[0][0] = dstPoints.get(i).getCenterPoint().getX(); + dstArr[0][1] = dstPoints.get(i).getCenterPoint().getY(); + dstArr[1][0] = dstPoints.get(j).getCenterPoint().getX(); + dstArr[1][1] = dstPoints.get(j).getCenterPoint().getY(); + dstArr[2][0] = dstPoints.get(k).getCenterPoint().getX(); + dstArr[2][1] = dstPoints.get(k).getCenterPoint().getY(); + dstArr[3][0] = dstPoints.get(l).getCenterPoint().getX(); + dstArr[3][1] = dstPoints.get(l).getCenterPoint().getY(); + dstPointsList.add(dstArr); + } + } + } + } + + // 根据海伦公式(Heron's formula)计算4边形面积 + double maxArea = 0; + int index = -1; + for (int i = 0; i < dstPointsList.size(); i++) { + double[][] dstArr = dstPointsList.get(i); + double area = PointUtils.getQuadArea(manager, dstArr); + if 
(area > maxArea) { + maxArea = area; + index = i; + } + + } + + double[][] srcArr = srcPointsList.get(index); + double[][] dstArr = dstPointsList.get(index); + + List srcQuadPoints = new ArrayList<>(); + List dstQuadPoints = new ArrayList<>(); + for (int i = 0; i < 4; i++) { + double x = srcArr[i][0]; + double y = srcArr[i][1]; + Point point1 = new Point((int) x, (int) y); + srcQuadPoints.add(point1); + + x = dstArr[i][0]; + y = dstArr[i][1]; + Point point2 = new Point((int) x, (int) y); + dstQuadPoints.add(point2); + } + + return new Pair<>(srcQuadPoints, dstQuadPoints); + } + + /** + * 透视变换 + * + * @return + */ + + public static ProjItemBean projTransform(List srcQuadPoints, List dstQuadPoints, Image templateImg, Image targetImg) { + Mat srcPoint2f = OpenCVUtils.toMat(srcQuadPoints); + Mat dstPoint2f = OpenCVUtils.toMat(dstQuadPoints); + + // 透视变换矩阵 + // perspective transformation + org.opencv.core.Mat warp_mat = Imgproc.getPerspectiveTransform(srcPoint2f, dstPoint2f); + + // 透视变换 + // perspective transformation + Mat mat = OpenCVUtils.perspectiveTransform((Mat) targetImg.getWrappedImage(), (Mat) templateImg.getWrappedImage(), srcPoint2f, dstPoint2f); + Image newImg = OpenCVImageFactory.getInstance().fromImage(mat); + ProjItemBean projItemBean = new ProjItemBean(); + projItemBean.setImage(newImg); + projItemBean.setWarpMat(warp_mat); + + return projItemBean; + } + +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBox.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBox.java new file mode 100755 index 00000000..a358fdb8 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBox.java @@ -0,0 +1,44 @@ +package me.aias.example.utils.common; + +import ai.djl.ndarray.NDArray; + +public class RotatedBox implements Comparable { + private NDArray box; + private String text; + + public RotatedBox(NDArray box, String text) { + this.box 
= box; + this.text = text; + } + + /** + * 将左上角 Y 坐标升序排序 + * + * @param o + * @return + */ + @Override + public int compareTo(RotatedBox o) { + NDArray lowBox = this.getBox(); + NDArray highBox = o.getBox(); + float lowY = lowBox.toFloatArray()[1]; + float highY = highBox.toFloatArray()[1]; + return (lowY < highY) ? -1 : 1; + } + + public NDArray getBox() { + return box; + } + + public void setBox(NDArray box) { + this.box = box; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBoxCompX.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBoxCompX.java new file mode 100755 index 00000000..5f713817 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/common/RotatedBoxCompX.java @@ -0,0 +1,44 @@ +package me.aias.example.utils.common; + +import ai.djl.ndarray.NDArray; + +public class RotatedBoxCompX implements Comparable { + private NDArray box; + private String text; + + public RotatedBoxCompX(NDArray box, String text) { + this.box = box; + this.text = text; + } + + /** + * 将左上角 X 坐标升序排序 + * + * @param o + * @return + */ + @Override + public int compareTo(RotatedBoxCompX o) { + NDArray leftBox = this.getBox(); + NDArray rightBox = o.getBox(); + float leftX = leftBox.toFloatArray()[0]; + float rightX = rightBox.toFloatArray()[0]; + return (leftX < rightX) ? 
-1 : 1; + } + + public NDArray getBox() { + return box; + } + + public void setBox(NDArray box) { + this.box = box; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/BoundFinder.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/BoundFinder.java new file mode 100755 index 00000000..5af53127 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/BoundFinder.java @@ -0,0 +1,122 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package me.aias.example.utils.detection; + +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.Point; +import ai.djl.modality.cv.output.Rectangle; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.stream.Collectors; + +/** Compute the bound of single colored region. */ +public class BoundFinder { + + private final int[] deltaX = {0, 1, -1, 0}; + private final int[] deltaY = {1, 0, 0, -1}; + private List> pointsCollection; + private int width; + private int height; + + /** + * Compute the bound based on the boolean mask. 
+ * + * @param grid the 2D boolean mask that defines the region + */ + public BoundFinder(boolean[][] grid) { + pointsCollection = new ArrayList<>(); + width = grid.length; + height = grid[0].length; + boolean[][] visited = new boolean[width][height]; + // get all points connections + for (int i = 0; i < width; i++) { + for (int j = 0; j < height; j++) { + if (grid[i][j] && !visited[i][j]) { + pointsCollection.add(bfs(grid, i, j, visited)); + } + } + } + } + + /** + * Gets all points from the region. + * + * @return all connected points + */ + public List> getPoints() { + return pointsCollection; + } + + /** + * Compute rectangle bounding boxes. + * + * @return the region defined by boxes + */ + public List getBoxes() { + return pointsCollection.stream() + .parallel() + .map( + points -> { + double[] minMax = {Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1}; + points.forEach( + p -> { + minMax[0] = Math.min(minMax[0], p.getX()); + minMax[1] = Math.min(minMax[1], p.getY()); + minMax[2] = Math.max(minMax[2], p.getX()); + minMax[3] = Math.max(minMax[3], p.getY()); + }); + return new Rectangle( + minMax[1], + minMax[0], + minMax[3] - minMax[1], + minMax[2] - minMax[0]); + }) + .filter(rect -> rect.getWidth() * width > 5.0 && rect.getHeight() * height > 5.0) + .collect(Collectors.toList()); + } + + private List bfs(boolean[][] grid, int x, int y, boolean[][] visited) { + Queue queue = new ArrayDeque<>(); + queue.offer(new Point(x, y)); + visited[x][y] = true; + + List points = new ArrayList<>(); + while (!queue.isEmpty()) { + Point point = queue.poll(); + points.add(new Point(point.getX() / width, point.getY() / height)); + for (int direction = 0; direction < 4; direction++) { + int newX = (int) point.getX() + deltaX[direction]; + int newY = (int) point.getY() + deltaY[direction]; + if (!isVaild(grid, newX, newY, visited)) { + continue; + } + queue.offer(new Point(newX, newY)); + visited[newX][newY] = true; + } + } + return points; + } + + private boolean 
isVaild(boolean[][] grid, int x, int y, boolean[][] visited) { + if (x < 0 || x >= width || y < 0 || y >= height) { + return false; + } + if (visited[x][y]) { + return false; + } + return grid[x][y]; + } +} \ No newline at end of file diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OCRDetectionTranslator.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OCRDetectionTranslator.java new file mode 100755 index 00000000..947abbbd --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OCRDetectionTranslator.java @@ -0,0 +1,514 @@ +package me.aias.example.utils.detection; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import me.aias.example.utils.opencv.NDArrayUtils; +import org.opencv.core.*; +import org.opencv.imgproc.Imgproc; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class OCRDetectionTranslator implements Translator { + // det_algorithm == "DB" + private final float thresh = 0.3f; + private final boolean use_dilation = false; + private final String score_mode = "fast"; + private final String box_type = "quad"; + + private final int limit_side_len; + private final int max_candidates; + private final int min_size; + private final float box_thresh; + private final float unclip_ratio; + private float ratio_h; + private float ratio_w; + private int img_height; + private int img_width; + + public OCRDetectionTranslator(Map arguments) { + limit_side_len = + arguments.containsKey("limit_side_len") + ? 
Integer.parseInt(arguments.get("limit_side_len").toString()) + : 960; + max_candidates = + arguments.containsKey("max_candidates") + ? Integer.parseInt(arguments.get("max_candidates").toString()) + : 1000; + min_size = + arguments.containsKey("min_size") + ? Integer.parseInt(arguments.get("min_size").toString()) + : 3; + box_thresh = + arguments.containsKey("box_thresh") + ? Float.parseFloat(arguments.get("box_thresh").toString()) + : 0.6f; // 0.5f + unclip_ratio = + arguments.containsKey("unclip_ratio") + ? Float.parseFloat(arguments.get("unclip_ratio").toString()) + : 1.6f; + } + + @Override + public NDList processOutput(TranslatorContext ctx, NDList list) { + NDManager manager = ctx.getNDManager(); + NDArray pred = list.singletonOrThrow(); + pred = pred.squeeze(); + NDArray segmentation = pred.gt(thresh); // thresh=0.3 .mul(255f) + + segmentation = segmentation.toType(DataType.UINT8, true); + Shape shape = segmentation.getShape(); + int rows = (int) shape.get(0); + int cols = (int) shape.get(1); + + Mat newMask = new Mat(); + if (this.use_dilation) { + Mat mask = new Mat(); + //convert from NDArray to Mat + Mat srcMat = NDArrayUtils.uint8NDArrayToMat(segmentation); + // size 越小,腐蚀的单位越小,图片越接近原图 + // Mat dilation_kernel = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(2, 2)); + Mat dilation_kernel = NDArrayUtils.uint8ArrayToMat(new byte[][]{{1, 1}, {1, 1}}); + /** + * 膨胀说明: 图像的一部分区域与指定的核进行卷积, 求核的最`大`值并赋值给指定区域。 膨胀可以理解为图像中`高亮区域`的'领域扩大'。 + * 意思是高亮部分会侵蚀不是高亮的部分,使高亮部分越来越多。 + */ + Imgproc.dilate(srcMat, mask, dilation_kernel); + //destination Matrix + Scalar scalar = new Scalar(255); + Core.multiply(mask, scalar, newMask); + // release Mat + mask.release(); + srcMat.release(); + dilation_kernel.release(); + } else { + Mat srcMat = NDArrayUtils.uint8NDArrayToMat(segmentation); + //destination Matrix + Scalar scalar = new Scalar(255); + Core.multiply(srcMat, scalar, newMask); + // release Mat + srcMat.release(); + } + + NDArray boxes = 
boxes_from_bitmap(manager, pred, newMask); + + //boxes[:, :, 0] = boxes[:, :, 0] / ratio_w + NDArray boxes1 = boxes.get(":, :, 0").div(ratio_w); + boxes.set(new NDIndex(":, :, 0"), boxes1); + //boxes[:, :, 1] = boxes[:, :, 1] / ratio_h + NDArray boxes2 = boxes.get(":, :, 1").div(ratio_h); + boxes.set(new NDIndex(":, :, 1"), boxes2); + + NDList dt_boxes = this.filter_tag_det_res(boxes); + + dt_boxes.detach(); + + // release Mat + newMask.release(); + + return dt_boxes; + } + + + private NDList filter_tag_det_res(NDArray dt_boxes) { + NDList boxesList = new NDList(); + + int num = (int) dt_boxes.getShape().get(0); + for (int i = 0; i < num; i++) { + NDArray box = dt_boxes.get(i); + box = order_points_clockwise(box); + box = clip_det_res(box); + float[] box0 = box.get(0).toFloatArray(); + float[] box1 = box.get(1).toFloatArray(); + float[] box3 = box.get(3).toFloatArray(); + int rect_width = (int) Math.sqrt(Math.pow(box1[0] - box0[0], 2) + Math.pow(box1[1] - box0[1], 2)); + int rect_height = (int) Math.sqrt(Math.pow(box3[0] - box0[0], 2) + Math.pow(box3[1] - box0[1], 2)); + if (rect_width <= 3 || rect_height <= 3) + continue; + boxesList.add(box); + } + + return boxesList; + } + + private NDArray clip_det_res(NDArray points) { + for (int i = 0; i < points.getShape().get(0); i++) { + int value = Math.max((int) points.get(i, 0).toFloatArray()[0], 0); + value = Math.min(value, img_width - 1); + points.set(new NDIndex(i + ",0"), value); + value = Math.max((int) points.get(i, 1).toFloatArray()[0], 0); + value = Math.min(value, img_height - 1); + points.set(new NDIndex(i + ",1"), value); + } + + return points; + } + + /** + * sort the points based on their x-coordinates + * 顺时针 + * + * @param pts + * @return + */ + + private NDArray order_points_clockwise(NDArray pts) { + NDList list = new NDList(); + long[] indexes = pts.get(":, 0").argSort().toLongArray(); + + // grab the left-most and right-most points from the sorted + // x-roodinate points + Shape s1 = pts.getShape(); 
+ NDArray leftMost1 = pts.get(indexes[0] + ",:"); + NDArray leftMost2 = pts.get(indexes[1] + ",:"); + NDArray leftMost = leftMost1.concat(leftMost2).reshape(2, 2); + NDArray rightMost1 = pts.get(indexes[2] + ",:"); + NDArray rightMost2 = pts.get(indexes[3] + ",:"); + NDArray rightMost = rightMost1.concat(rightMost2).reshape(2, 2); + + // now, sort the left-most coordinates according to their + // y-coordinates so we can grab the top-left and bottom-left + // points, respectively + indexes = leftMost.get(":, 1").argSort().toLongArray(); + NDArray lt = leftMost.get(indexes[0] + ",:"); + NDArray lb = leftMost.get(indexes[1] + ",:"); + indexes = rightMost.get(":, 1").argSort().toLongArray(); + NDArray rt = rightMost.get(indexes[0] + ",:"); + NDArray rb = rightMost.get(indexes[1] + ",:"); + + list.add(lt); + list.add(rt); + list.add(rb); + list.add(lb); + + NDArray rect = NDArrays.concat(list).reshape(4, 2); + return rect; + } + + /** + * Get boxes from the binarized image predicted by DB + * + * @param manager + * @param pred the binarized image predicted by DB. + * @param bitmap new 'pred' after threshold filtering. 
+ */ + private NDArray boxes_from_bitmap(NDManager manager, NDArray pred, Mat bitmap) { + int dest_height = (int) pred.getShape().get(0); + int dest_width = (int) pred.getShape().get(1); + int height = bitmap.rows(); + int width = bitmap.cols(); + + List contours = new ArrayList<>(); + Mat hierarchy = new Mat(); + // 寻找轮廓 + Imgproc.findContours( + bitmap, + contours, + hierarchy, + Imgproc.RETR_LIST, + Imgproc.CHAIN_APPROX_SIMPLE); + + int num_contours = Math.min(contours.size(), max_candidates); + NDList boxList = new NDList(); + float[] scores = new float[num_contours]; + + for (int index = 0; index < num_contours; index++) { + MatOfPoint contour = contours.get(index); + MatOfPoint2f newContour = new MatOfPoint2f(contour.toArray()); + float[][] pointsArr = new float[4][2]; + int sside = get_mini_boxes(newContour, pointsArr); + if (sside < this.min_size) + continue; + NDArray points = manager.create(pointsArr); + float score = box_score_fast(manager, pred, points); + if (score < this.box_thresh) + continue; + + NDArray box = unclip(manager, points); // TODO get_mini_boxes(box) + + // box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + NDArray boxes1 = box.get(":,0").div(width).mul(dest_width).round().clip(0, dest_width); + box.set(new NDIndex(":, 0"), boxes1); + // box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + NDArray boxes2 = box.get(":,1").div(height).mul(dest_height).round().clip(0, dest_height); + box.set(new NDIndex(":, 1"), boxes2); + + boxList.add(box); + scores[index] = score; + + // release memory + contour.release(); + newContour.release(); + } + + NDArray boxes = NDArrays.stack(boxList); + + // release + hierarchy.release(); + + return boxes; + } + + /** + * Shrink or expand the boxaccording to 'unclip_ratio' + * + * @param points The predicted box. 
+ * @return uncliped box + */ + private NDArray unclip(NDManager manager, NDArray points) { + points = order_points_clockwise(points); + float[] pointsArr = points.toFloatArray(); + float[] lt = java.util.Arrays.copyOfRange(pointsArr, 0, 2); + float[] lb = java.util.Arrays.copyOfRange(pointsArr, 6, 8); + + float[] rt = java.util.Arrays.copyOfRange(pointsArr, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(pointsArr, 4, 6); + + float width = distance(lt, rt); + float height = distance(lt, lb); + + if (width > height) { + float k = (lt[1] - rt[1]) / (lt[0] - rt[0]); // y = k * x + b + + float delta_dis = height; + float delta_x = (float) Math.sqrt((delta_dis * delta_dis) / (k * k + 1)); + float delta_y = Math.abs(k * delta_x); + + if (k > 0) { + pointsArr[0] = lt[0] - delta_x + delta_y; + pointsArr[1] = lt[1] - delta_y - delta_x; + pointsArr[2] = rt[0] + delta_x + delta_y; + pointsArr[3] = rt[1] + delta_y - delta_x; + + pointsArr[4] = rb[0] + delta_x - delta_y; + pointsArr[5] = rb[1] + delta_y + delta_x; + pointsArr[6] = lb[0] - delta_x - delta_y; + pointsArr[7] = lb[1] - delta_y + delta_x; + } else { + pointsArr[0] = lt[0] - delta_x - delta_y; + pointsArr[1] = lt[1] + delta_y - delta_x; + pointsArr[2] = rt[0] + delta_x - delta_y; + pointsArr[3] = rt[1] - delta_y - delta_x; + + pointsArr[4] = rb[0] + delta_x + delta_y; + pointsArr[5] = rb[1] - delta_y + delta_x; + pointsArr[6] = lb[0] - delta_x + delta_y; + pointsArr[7] = lb[1] + delta_y + delta_x; + } + } else { + float k = (lt[1] - rt[1]) / (lt[0] - rt[0]); // y = k * x + b + + float delta_dis = width; + float delta_y = (float) Math.sqrt((delta_dis * delta_dis) / (k * k + 1)); + float delta_x = Math.abs(k * delta_y); + + if (k > 0) { + pointsArr[0] = lt[0] + delta_x - delta_y; + pointsArr[1] = lt[1] - delta_y - delta_x; + pointsArr[2] = rt[0] + delta_x + delta_y; + pointsArr[3] = rt[1] - delta_y + delta_x; + + pointsArr[4] = rb[0] - delta_x + delta_y; + pointsArr[5] = rb[1] + delta_y + delta_x; + pointsArr[6] = 
lb[0] - delta_x - delta_y; + pointsArr[7] = lb[1] + delta_y - delta_x; + } else { + pointsArr[0] = lt[0] - delta_x - delta_y; + pointsArr[1] = lt[1] - delta_y + delta_x; + pointsArr[2] = rt[0] - delta_x + delta_y; + pointsArr[3] = rt[1] - delta_y - delta_x; + + pointsArr[4] = rb[0] + delta_x + delta_y; + pointsArr[5] = rb[1] + delta_y - delta_x; + pointsArr[6] = lb[0] + delta_x - delta_y; + pointsArr[7] = lb[1] + delta_y + delta_x; + } + } + points = manager.create(pointsArr).reshape(4, 2); + + return points; + } + + private float distance(float[] point1, float[] point2) { + float disX = point1[0] - point2[0]; + float disY = point1[1] - point2[1]; + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + /** + * Get boxes from the contour or box. + * + * @param contour The predicted contour. + * @param pointsArr The predicted box. + * @return smaller side of box + */ + private int get_mini_boxes(MatOfPoint2f contour, float[][] pointsArr) { + // https://blog.csdn.net/qq_37385726/article/details/82313558 + // bounding_box[1] - rect 返回矩形的长和宽 + RotatedRect rect = Imgproc.minAreaRect(contour); + Mat points = new Mat(); + Imgproc.boxPoints(rect, points); + + float[][] fourPoints = new float[4][2]; + for (int row = 0; row < 4; row++) { + fourPoints[row][0] = (float) points.get(row, 0)[0]; + fourPoints[row][1] = (float) points.get(row, 1)[0]; + } + + float[] tmpPoint = new float[2]; + for (int i = 0; i < 4; i++) { + for (int j = i + 1; j < 4; j++) { + if (fourPoints[j][0] < fourPoints[i][0]) { + tmpPoint[0] = fourPoints[i][0]; + tmpPoint[1] = fourPoints[i][1]; + fourPoints[i][0] = fourPoints[j][0]; + fourPoints[i][1] = fourPoints[j][1]; + fourPoints[j][0] = tmpPoint[0]; + fourPoints[j][1] = tmpPoint[1]; + } + } + } + + int index_1 = 0; + int index_2 = 1; + int index_3 = 2; + int index_4 = 3; + + if (fourPoints[1][1] > fourPoints[0][1]) { + index_1 = 0; + index_4 = 1; + } else { + index_1 = 1; + index_4 = 0; + } + + if (fourPoints[3][1] > 
fourPoints[2][1]) { + index_2 = 2; + index_3 = 3; + } else { + index_2 = 3; + index_3 = 2; + } + + pointsArr[0] = fourPoints[index_1]; + pointsArr[1] = fourPoints[index_2]; + pointsArr[2] = fourPoints[index_3]; + pointsArr[3] = fourPoints[index_4]; + + int height = rect.boundingRect().height; + int width = rect.boundingRect().width; + int sside = Math.min(height, width); + + // release + points.release(); + + return sside; + } + + /** + * Calculate the score of box. + * + * @param bitmap The binarized image predicted by DB. + * @param points The predicted box + * @return + */ + private float box_score_fast(NDManager manager, NDArray bitmap, NDArray points) { + NDArray box = points.get(":"); + long h = bitmap.getShape().get(0); + long w = bitmap.getShape().get(1); + // xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + int xmin = box.get(":, 0").min().floor().clip(0, w - 1).toType(DataType.INT32, true).toIntArray()[0]; + int xmax = box.get(":, 0").max().ceil().clip(0, w - 1).toType(DataType.INT32, true).toIntArray()[0]; + int ymin = box.get(":, 1").min().floor().clip(0, h - 1).toType(DataType.INT32, true).toIntArray()[0]; + int ymax = box.get(":, 1").max().ceil().clip(0, h - 1).toType(DataType.INT32, true).toIntArray()[0]; + + NDArray mask = manager.zeros(new Shape(ymax - ymin + 1, xmax - xmin + 1), DataType.UINT8); + + box.set(new NDIndex(":, 0"), box.get(":, 0").sub(xmin)); + box.set(new NDIndex(":, 1"), box.get(":, 1").sub(ymin)); + + //mask - convert from NDArray to Mat + Mat maskMat = NDArrayUtils.uint8NDArrayToMat(mask); + + //mask - convert from NDArray to Mat - 4 rows, 2 cols + Mat boxMat = NDArrayUtils.floatNDArrayToMat(box, CvType.CV_32S); + +// boxMat.reshape(1, new int[]{1, 4, 2}); + List pts = new ArrayList<>(); + MatOfPoint matOfPoint = NDArrayUtils.matToMatOfPoint(boxMat); // new MatOfPoint(boxMat); + pts.add(matOfPoint); + Imgproc.fillPoly(maskMat, pts, new Scalar(1)); + + + NDArray subBitMap = bitmap.get(ymin + ":" + (ymax + 1) + 
"," + xmin + ":" + (xmax + 1)); + Mat bitMapMat = NDArrayUtils.floatNDArrayToMat(subBitMap); + + Scalar score = Core.mean(bitMapMat, maskMat); + float scoreValue = (float) score.val[0]; + // release + maskMat.release(); + boxMat.release(); + bitMapMat.release(); + + return scoreValue; + } + + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDArray img = input.toNDArray(ctx.getNDManager()); + int h = input.getHeight(); + int w = input.getWidth(); + img_height = h; + img_width = w; + + // limit the max side + float ratio = 1.0f; + if (Math.max(h, w) > limit_side_len) { + if (h > w) { + ratio = (float) limit_side_len / (float) h; + } else { + ratio = (float) limit_side_len / (float) w; + } + } + + int resize_h = (int) (h * ratio); + int resize_w = (int) (w * ratio); + + resize_h = Math.round((float) resize_h / 32f) * 32; + resize_w = Math.round((float) resize_w / 32f) * 32; + + ratio_h = resize_h / (float) h; + ratio_w = resize_w / (float) w; + + img = NDImageUtils.resize(img, resize_w, resize_h); + + img = NDImageUtils.toTensor(img); + + img = + NDImageUtils.normalize( + img, + new float[]{0.485f, 0.456f, 0.406f}, + new float[]{0.229f, 0.224f, 0.225f}); + + img = img.expandDims(0); + + return new NDList(img); + } + + @Override + public Batchifier getBatchifier() { + return null; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OcrV3Detection.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OcrV3Detection.java new file mode 100755 index 00000000..f2388d3f --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/OcrV3Detection.java @@ -0,0 +1,73 @@ +package me.aias.example.utils.detection; + +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDList; +import ai.djl.repository.zoo.Criteria; +import ai.djl.training.util.ProgressBar; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import 
java.nio.file.Paths; +import java.util.concurrent.ConcurrentHashMap; + +public final class OcrV3Detection { + + private static final Logger logger = LoggerFactory.getLogger(OcrV3Detection.class); + + public OcrV3Detection() { + } + + /** + * 中文文本检测 + * @return + */ + public Criteria chDetCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + .optModelPath(Paths.get("models/ch_PP-OCRv3_det_infer_onnx.zip")) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + /** + * 英文文本检测 + * @return + */ + public Criteria enDetCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + .optModelPath(Paths.get("models/en_PP-OCRv3_det_infer_onnx.zip")) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } + + /** + * 多语言文本检测 + * @return + */ + public Criteria mlDetCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, NDList.class) + .optModelPath(Paths.get("models/Multilingual_PP-OCRv3_det_infer_onnx.zip")) + .optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap())) + .optProgress(new ProgressBar()) + .build(); + + return criteria; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/PpWordDetectionTranslator.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/PpWordDetectionTranslator.java new file mode 100755 index 00000000..7fbaa138 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/detection/PpWordDetectionTranslator.java @@ -0,0 +1,107 @@ +package me.aias.example.utils.detection; + +import 
ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +public class PpWordDetectionTranslator implements Translator { + + private final int max_side_len; + + public PpWordDetectionTranslator(Map arguments) { + max_side_len = + arguments.containsKey("maxLength") + ? Integer.parseInt(arguments.get("maxLength").toString()) + : 960; + } + + @Override + public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { + NDArray result = list.singletonOrThrow(); + result = result.squeeze().mul(255f).toType(DataType.UINT8, true).gt(0.3); // thresh=0.3 + boolean[] flattened = result.toBooleanArray(); + Shape shape = result.getShape(); + int w = (int) shape.get(0); + int h = (int) shape.get(1); + boolean[][] grid = new boolean[w][h]; + IntStream.range(0, flattened.length) + .parallel() + .forEach(i -> grid[i / h][i % h] = flattened[i]); + List boxes = new BoundFinder(grid).getBoxes(); + List names = new ArrayList<>(); + List probs = new ArrayList<>(); + int boxSize = boxes.size(); + for (int i = 0; i < boxSize; i++) { + names.add("word"); + probs.add(1.0); + } + return new DetectedObjects(names, probs, boxes); + } + + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDArray img = input.toNDArray(ctx.getNDManager()); + int h = input.getHeight(); + int w = input.getWidth(); + int resize_w = w; + int resize_h = h; + + // limit the max side + float ratio = 1.0f; + if (Math.max(resize_h, resize_w) > max_side_len) { + if (resize_h > resize_w) { + ratio = (float) 
max_side_len / (float) resize_h; + } else { + ratio = (float) max_side_len / (float) resize_w; + } + } + + resize_h = (int) (resize_h * ratio); + resize_w = (int) (resize_w * ratio); + + if (resize_h % 32 == 0) { + resize_h = resize_h; + } else if (Math.floor((float) resize_h / 32f) <= 1) { + resize_h = 32; + } else { + resize_h = (int) Math.floor((float) resize_h / 32f) * 32; + } + + if (resize_w % 32 == 0) { + resize_w = resize_w; + } else if (Math.floor((float) resize_w / 32f) <= 1) { + resize_w = 32; + } else { + resize_w = (int) Math.floor((float) resize_w / 32f) * 32; + } + + img = NDImageUtils.resize(img, resize_w, resize_h); + img = NDImageUtils.toTensor(img); + img = + NDImageUtils.normalize( + img, + new float[]{0.485f, 0.456f, 0.406f}, + new float[]{0.229f, 0.224f, 0.225f}); + img = img.expandDims(0); + return new NDList(img); + } + + @Override + public Batchifier getBatchifier() { + return null; + } + +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/NDArrayUtils.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/NDArrayUtils.java new file mode 100755 index 00000000..96cda290 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/NDArrayUtils.java @@ -0,0 +1,247 @@ +package me.aias.example.utils.opencv; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import org.opencv.core.CvType; +import org.opencv.core.Mat; +import org.opencv.core.MatOfPoint; +import org.opencv.core.Point; + +import java.util.ArrayList; +import java.util.List; + +public class NDArrayUtils { + public static NDArray Sigmoid(NDArray input) { + // Sigmoid 函数,即f(x)=1/(1+e-x) + return input.neg().exp().add(1).pow(-1); + } + + /** + * np.arctan2和np.arctan都是计算反正切值的NumPy函数,但它们的参数和返回值不同。一般来说,np.arctan2的参数为(y, x), + * 返回值为[-π, 
π]之间的弧度值;而np.arctan的参数为x,返回值为[-π/2, π/2]之间的弧度值。两者之间的换算关系是: + * np.arctan(y/x) = np.arctan2(y, x)(当x>0时), + * 或 np.pi + np.arctan(y/x) = np.arctan2(y, x) (当x<0且y>=0时), + * 或 np.pi - np.arctan(y/x) = np.arctan2(y, x) (当x<0且y<0时)。 + * @param y + * @param x + * @return + */ + public static NDArray arctan2(NDArray y, NDArray x) { + NDArray x_neg = x.lt(0).toType(DataType.INT32, false); + NDArray y_pos = y.gte(0).toType(DataType.INT32, false); + NDArray y_neg = y.lt(0).toType(DataType.INT32, false); + + NDArray theta = y.div(x).atan(); + // np.arctan(y/x) + np.pi = np.arctan2(y, x) (当x<0且y>=0时) + theta = theta.add(x_neg.mul(y_pos).mul((float) Math.PI)); + // np.arctan(y/x) - np.pi = np.arctan2(y, x) (当x<0且y<0时) + theta = theta.add(x_neg.mul(y_neg).mul(-(float) Math.PI)); + + theta = theta.mul(180).div((float) Math.PI); + + return theta; + } + + + public static NDArray maxPool(NDManager manager, NDArray heat, int ksize, int stride, int padding) { + int rows = (int) (heat.getShape().get(0)); + int cols = (int) (heat.getShape().get(1)); + // hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + NDArray max_pool2d = manager.zeros(new Shape(rows + 2 * padding, cols + 2 * padding)); + max_pool2d.set(new NDIndex(padding + ":" + (rows + padding) + ","+ padding + ":" + (cols + padding)), heat); + float[][] max_pool2d_arr = NDArrayUtils.floatNDArrayToArray(max_pool2d); + float[][] arr = new float[rows][cols]; + + for (int row = 0; row < rows; row++) { + for (int col = 0; col < cols; col++) { + float max = max_pool2d_arr[row][col]; + for (int i = row; i < row + ksize; i++) { + for (int j = col; j < col + ksize; j++) { + if (max_pool2d_arr[i][j] > max) { + max = max_pool2d_arr[i][j]; + } + } + } + arr[row][col] = max; + } + } + + NDArray hmax = manager.create(arr).reshape(rows, cols); + return hmax; + } + + public static MatOfPoint matToMatOfPoint(Mat mat) { + int rows = mat.rows(); + MatOfPoint matOfPoint = new MatOfPoint(); + + List list = new ArrayList<>(); 
+ for (int i = 0; i < rows; i++) { + Point point = new Point((float) mat.get(i, 0)[0], (float) mat.get(i, 1)[0]); + list.add(point); + } + matOfPoint.fromList(list); + + return matOfPoint; + } + + public static int[][] intNDArrayToArray(NDArray ndArray) { + int rows = (int) (ndArray.getShape().get(0)); + int cols = (int) (ndArray.getShape().get(1)); + int[][] arr = new int[rows][cols]; + + int[] arrs = ndArray.toIntArray(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + arr[i][j] = arrs[i * cols + j]; + } + } + return arr; + } + + public static float[][] floatNDArrayToArray(NDArray ndArray) { + int rows = (int) (ndArray.getShape().get(0)); + int cols = (int) (ndArray.getShape().get(1)); + float[][] arr = new float[rows][cols]; + + float[] arrs = ndArray.toFloatArray(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + arr[i][j] = arrs[i * cols + j]; + } + } + return arr; + } + + public static double[][] matToDoubleArray(Mat mat) { + int rows = mat.rows(); + int cols = mat.cols(); + + double[][] doubles = new double[rows][cols]; + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + doubles[i][j] = mat.get(i, j)[0]; + } + } + + return doubles; + } + + public static float[][] matToFloatArray(Mat mat) { + int rows = mat.rows(); + int cols = mat.cols(); + + float[][] floats = new float[rows][cols]; + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + floats[i][j] = (float) mat.get(i, j)[0]; + } + } + + return floats; + } + + public static byte[][] matToUint8Array(Mat mat) { + int rows = mat.rows(); + int cols = mat.cols(); + + byte[][] bytes = new byte[rows][cols]; + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + bytes[i][j] = (byte) mat.get(i, j)[0]; + } + } + + return bytes; + } + + public static Mat floatNDArrayToMat(NDArray ndArray, int cvType) { + int rows = (int) (ndArray.getShape().get(0)); + int cols = (int) (ndArray.getShape().get(1)); + Mat mat = 
new Mat(rows, cols, cvType); + + float[] arrs = ndArray.toFloatArray(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + mat.put(i, j, arrs[i * cols + j]); + } + } + return mat; + } + + public static Mat floatNDArrayToMat(NDArray ndArray) { + int rows = (int) (ndArray.getShape().get(0)); + int cols = (int) (ndArray.getShape().get(1)); + Mat mat = new Mat(rows, cols, CvType.CV_32F); + + float[] arrs = ndArray.toFloatArray(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + mat.put(i, j, arrs[i * cols + j]); + } + } + + return mat; + + } + + public static Mat uint8NDArrayToMat(NDArray ndArray) { + int rows = (int) (ndArray.getShape().get(0)); + int cols = (int) (ndArray.getShape().get(1)); + Mat mat = new Mat(rows, cols, CvType.CV_8U); + + byte[] arrs = ndArray.toByteArray(); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + mat.put(i, j, arrs[i * cols + j]); + } + } + return mat; + } + + public static Mat floatArrayToMat(float[][] arr) { + int rows = arr.length; + int cols = arr[0].length; + Mat mat = new Mat(rows, cols, CvType.CV_32F); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + mat.put(i, j, arr[i][j]); + } + } + + return mat; + } + + public static Mat uint8ArrayToMat(byte[][] arr) { + int rows = arr.length; + int cols = arr[0].length; + Mat mat = new Mat(rows, cols, CvType.CV_8U); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + mat.put(i, j, arr[i][j]); + } + } + + return mat; + } + + // list 转 Mat + public static Mat toMat(List points) { + Mat mat = new Mat(points.size(), 2, CvType.CV_32F); + for (int i = 0; i < points.size(); i++) { + ai.djl.modality.cv.output.Point point = points.get(i); + mat.put(i, 0, (float) point.getX()); + mat.put(i, 1, (float) point.getY()); + } + + return mat; + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/OpenCVUtils.java 
b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/OpenCVUtils.java new file mode 100755 index 00000000..e6fe4ee4 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/opencv/OpenCVUtils.java @@ -0,0 +1,207 @@ +package me.aias.example.utils.opencv; + +import ai.djl.ndarray.NDArray; +import me.aias.example.utils.common.Point; +import me.aias.example.utils.common.PointUtils; +import org.opencv.core.CvType; +import org.opencv.core.Mat; +import org.opencv.core.MatOfPoint; +import org.opencv.core.Scalar; +import org.opencv.imgproc.Imgproc; + +import java.awt.image.BufferedImage; +import java.awt.image.DataBufferByte; +import java.util.ArrayList; +import java.util.List; + + +public class OpenCVUtils { + /** + * Mat to BufferedImage + * + * @param mat + * @return + */ + public static BufferedImage mat2Image(Mat mat) { + int width = mat.width(); + int height = mat.height(); + byte[] data = new byte[width * height * (int) mat.elemSize()]; + Imgproc.cvtColor(mat, mat, 4); + mat.get(0, 0, data); + BufferedImage ret = new BufferedImage(width, height, 5); + ret.getRaster().setDataElements(0, 0, width, height, data); + return ret; + } + + /** + * BufferedImage to Mat + * + * @param img + * @return + */ + public static Mat image2Mat(BufferedImage img) { + int width = img.getWidth(); + int height = img.getHeight(); + byte[] data = ((DataBufferByte) img.getRaster().getDataBuffer()).getData(); + Mat mat = new Mat(height, width, CvType.CV_8UC3); + mat.put(0, 0, data); + return mat; + } + + // list 转 Mat + public static org.opencv.core.Mat toMat(List points) { + org.opencv.core.Mat mat = new org.opencv.core.Mat(points.size(), 2, CvType.CV_32F); + for (int i = 0; i < points.size(); i++) { + Point point = points.get(i); + mat.put(i, 0, (float) point.getX()); + mat.put(i, 1, (float) point.getY()); + } + + return mat; + } + + public static Mat warpPerspective(Mat src, Mat dst, Mat warp_mat) { + Mat dstClone = dst.clone(); +// 
org.opencv.core.Mat mat = new org.opencv.core.Mat(dst.rows(), dst.cols(), CvType.CV_8UC3); + Imgproc.warpPerspective(src, dstClone, warp_mat, dst.size()); + return dstClone; + } + + public static Mat perspectiveTransform(Mat src, Mat srcPoints, Mat dstPoints) { + Mat dst = src.clone(); + Mat warp_mat = Imgproc.getPerspectiveTransform(srcPoints, dstPoints); + Imgproc.warpPerspective(src, dst, warp_mat, dst.size()); + warp_mat.release(); + + return dst; + } + + public static Mat perspectiveTransform(Mat src, Mat dst, Mat srcPoints, Mat dstPoints) { + Mat dstClone = dst.clone(); + Mat warp_mat = Imgproc.getPerspectiveTransform(srcPoints, dstPoints); + Imgproc.warpPerspective(src, dstClone, warp_mat, dst.size()); + warp_mat.release(); + + return dstClone; + } + + /** + * 图片裁剪 + * @param points + * @return + */ + public static int[] imgCrop(float[] points) { + int[] wh = new int[2]; + float[] lt = java.util.Arrays.copyOfRange(points, 0, 2); + float[] rt = java.util.Arrays.copyOfRange(points, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(points, 4, 6); + float[] lb = java.util.Arrays.copyOfRange(points, 6, 8); + wh[0] = (int) Math.max(PointUtils.distance(lt, rt), PointUtils.distance(rb, lb)); + wh[1] = (int) Math.max(PointUtils.distance(lt, lb), PointUtils.distance(rt, rb)); + return wh; + } + + /** + * 转正图片 + * @param mat + * @param points + * @return + */ + public static Mat perspectiveTransform(Mat mat, float[] points) { + float[] lt = java.util.Arrays.copyOfRange(points, 0, 2); + float[] rt = java.util.Arrays.copyOfRange(points, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(points, 4, 6); + float[] lb = java.util.Arrays.copyOfRange(points, 6, 8); + int img_crop_width = (int) Math.max(PointUtils.distance(lt, rt), PointUtils.distance(rb, lb)); + int img_crop_height = (int) Math.max(PointUtils.distance(lt, lb), PointUtils.distance(rt, rb)); + List srcPoints = new ArrayList<>(); + srcPoints.add(new Point((int)lt[0], (int)lt[1])); + srcPoints.add(new 
Point((int)rt[0], (int)rt[1])); + srcPoints.add(new Point((int)rb[0], (int)rb[1])); + srcPoints.add(new Point((int)lb[0], (int)lb[1])); + List dstPoints = new ArrayList<>(); + dstPoints.add(new Point(0, 0)); + dstPoints.add(new Point(img_crop_width, 0)); + dstPoints.add(new Point(img_crop_width, img_crop_height)); + dstPoints.add(new Point(0, img_crop_height)); + + Mat srcPoint2f = toMat(srcPoints); + Mat dstPoint2f = toMat(dstPoints); + + Mat cvMat = OpenCVUtils.perspectiveTransform(mat, srcPoint2f, dstPoint2f); + srcPoint2f.release(); + dstPoint2f.release(); + return cvMat; + } + /** + * 转正图片 - 废弃 + * @param mat + * @param points + * @return + */ + public Mat perspectiveTransformOld(Mat mat, float[] points) { + List pointList = new ArrayList<>(); + float[][] srcArr = new float[4][2]; + float min_X = Float.MAX_VALUE; + float min_Y = Float.MAX_VALUE; + float max_X = -1; + float max_Y = -1; + + for (int j = 0; j < 4; j++) { + org.opencv.core.Point pt = new org.opencv.core.Point(points[2 * j], points[2 * j + 1]); + pointList.add(pt); + srcArr[j][0] = points[2 * j]; + srcArr[j][1] = points[2 * j + 1]; + if (points[2 * j] > max_X) { + max_X = points[2 * j]; + } + if (points[2 * j] < min_X) { + min_X = points[2 * j]; + } + if (points[2 * j + 1] > max_Y) { + max_Y = points[2 * j + 1]; + } + if (points[2 * j + 1] < min_Y) { + min_Y = points[2 * j + 1]; + } + } + + Mat src = NDArrayUtils.floatArrayToMat(srcArr); + + float width = max_Y - min_Y; + float height = max_X - min_X; + + float[][] dstArr = new float[4][2]; + dstArr[0] = new float[]{0, 0}; + dstArr[1] = new float[]{width - 1, 0}; + dstArr[2] = new float[]{width - 1, height - 1}; + dstArr[3] = new float[]{0, height - 1}; + + Mat dst = NDArrayUtils.floatArrayToMat(dstArr); + return OpenCVUtils.perspectiveTransform(mat, src, dst); + } + + /** + * 画边框 + * @param mat + * @param squares + * @param topK + */ + public static void drawSquares(Mat mat, NDArray squares, int topK) { + for (int i = 0; i < topK; i++) { + float[] 
points = squares.get(i).toFloatArray(); + List matOfPoints = new ArrayList<>(); + MatOfPoint matOfPoint = new MatOfPoint(); + matOfPoints.add(matOfPoint); + List pointList = new ArrayList<>(); + for (int j = 0; j < 4; j++) { + org.opencv.core.Point pt = new org.opencv.core.Point(points[2 * j], points[2 * j + 1]); + pointList.add(pt); + Imgproc.circle(mat, pt, 10, new Scalar(0, 255, 255), -1); + Imgproc.putText(mat, "" + j, pt, Imgproc.FONT_HERSHEY_SCRIPT_SIMPLEX, 1.0, new Scalar(0, 255, 0), 1); + } + matOfPoint.fromList(pointList); + Imgproc.polylines(mat, matOfPoints, true, new Scalar(200, 200, 0), 5); + } + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/OcrV3Recognition.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/OcrV3Recognition.java new file mode 100755 index 00000000..a222aa6a --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/OcrV3Recognition.java @@ -0,0 +1,323 @@ +package me.aias.example.utils.recognition; + +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.Point; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.opencv.OpenCVImageFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import me.aias.example.utils.common.RotatedBox; +import me.aias.example.utils.opencv.NDArrayUtils; +import me.aias.example.utils.opencv.OpenCVUtils; +import org.opencv.core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.awt.image.BufferedImage; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * 
https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md + */ +public final class OcrV3Recognition { + + private static final Logger logger = LoggerFactory.getLogger(OcrV3Recognition.class); + + + public OcrV3Recognition() { + } + + /** + * 中文简体 + * @return + */ + public Criteria chRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/ch_PP-OCRv3_rec_infer_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator(new ConcurrentHashMap())) + .build(); + return criteria; + } + + /** + * 中文繁体 + * @return + */ + public Criteria chtRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/chinese_cht_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator(new ConcurrentHashMap())) + .build(); + return criteria; + } + + /** + * 英文 + * @return + */ + public Criteria enRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/en_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 韩语 + * @return + */ + public Criteria koreanRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/korean_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 日语 + * @return + */ + public Criteria 
japanRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/japan_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + + /** + * 泰米尔语 + * @return + */ + public Criteria taRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/ta_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 泰卢固语 + * @return + */ + public Criteria teRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/te_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 卡纳达文 + * @return + */ + public Criteria kaRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/ka_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 阿拉伯 + * + * arabic_lang = ['ar', 'fa', 'ug', 'ur'] + * + * @return + */ + public Criteria arabicRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/arabic_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + 
.optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 斯拉夫 + * 西里尔字母(英:Cyrillic,俄:Кириллица)源于希腊字母,普遍认为是由基督教传教士西里尔(827年–869年) + * 在9世纪为了方便在斯拉夫民族传播东正教所创立的,被斯拉夫民族广泛采用 + * + * cyrillic_lang = [ + * 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', + * 'dar', 'inh', 'che', 'lbe', 'lez', 'tab' + * ] + * + * @return + */ + public Criteria cyrillicRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/cyrillic_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + /** + * 梵文 + * + * devanagari_lang = [ + * 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', + * 'sa', 'bgc' + * ] + * + * @return + */ + public Criteria devanagariRecCriteria() { + Criteria criteria = + Criteria.builder() + .optEngine("OnnxRuntime") + .optModelName("inference") + .setTypes(Image.class, String.class) + .optModelPath(Paths.get("models/devanagari_PP-OCRv3_rec_onnx.zip")) + .optProgress(new ProgressBar()) + .optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap()))) + .build(); + return criteria; + } + + + public List predict(NDManager manager, + Image image, Predictor detector, Predictor recognizer) + throws TranslateException { + NDList boxes = detector.predict(image); + // 交给 NDManager自动管理内存 + // attach to manager for automatic memory management + boxes.attach(manager); + + List result = new ArrayList<>(); + long timeInferStart = System.currentTimeMillis(); + + Mat mat = (Mat) image.getWrappedImage(); + + for (int i = 0; i < boxes.size(); i++) { + NDArray box = boxes.get(i); + + float[] pointsArr = box.toFloatArray(); + float[] lt = java.util.Arrays.copyOfRange(pointsArr, 0, 2); + float[] rt = 
java.util.Arrays.copyOfRange(pointsArr, 2, 4); + float[] rb = java.util.Arrays.copyOfRange(pointsArr, 4, 6); + float[] lb = java.util.Arrays.copyOfRange(pointsArr, 6, 8); + int img_crop_width = (int) Math.max(distance(lt, rt), distance(rb, lb)); + int img_crop_height = (int) Math.max(distance(lt, lb), distance(rt, rb)); + List srcPoints = new ArrayList<>(); + srcPoints.add(new Point(lt[0], lt[1])); + srcPoints.add(new Point(rt[0], rt[1])); + srcPoints.add(new Point(rb[0], rb[1])); + srcPoints.add(new Point(lb[0], lb[1])); + List dstPoints = new ArrayList<>(); + dstPoints.add(new Point(0, 0)); + dstPoints.add(new Point(img_crop_width, 0)); + dstPoints.add(new Point(img_crop_width, img_crop_height)); + dstPoints.add(new Point(0, img_crop_height)); + + Mat srcPoint2f = NDArrayUtils.toMat(srcPoints); + Mat dstPoint2f = NDArrayUtils.toMat(dstPoints); + + Mat cvMat = OpenCVUtils.perspectiveTransform(mat, srcPoint2f, dstPoint2f); + + Image subImg = OpenCVImageFactory.getInstance().fromImage(cvMat); +// ImageUtils.saveImage(subImg, i + ".png", "build/output"); + + subImg = subImg.getSubImage(0, 0, img_crop_width, img_crop_height); + if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { + subImg = rotateImg(manager, subImg); + } + + String name = recognizer.predict(subImg); + RotatedBox rotatedBox = new RotatedBox(box, name); + result.add(rotatedBox); + + cvMat.release(); + srcPoint2f.release(); + dstPoint2f.release(); + + } + + long timeInferEnd = System.currentTimeMillis(); + System.out.println("time: " + (timeInferEnd - timeInferStart)); + + return result; + } + + private BufferedImage get_rotate_crop_image(Image image, NDArray box) { + return null; + } + + private float distance(float[] point1, float[] point2) { + float disX = point1[0] - point2[0]; + float disY = point1[1] - point2[1]; + float dis = (float) Math.sqrt(disX * disX + disY * disY); + return dis; + } + + private Image rotateImg(NDManager manager, Image image) { + NDArray rotated = 
NDImageUtils.rotate90(image.toNDArray(manager), 1); + return ImageFactory.getInstance().fromNDArray(rotated); + } +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/PpWordRecognitionTranslator.java b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/PpWordRecognitionTranslator.java new file mode 100755 index 00000000..cf7ca1d3 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/java/me/aias/example/utils/recognition/PpWordRecognitionTranslator.java @@ -0,0 +1,119 @@ +package me.aias.example.utils.recognition; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.Utils; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class PpWordRecognitionTranslator implements Translator { + private List table; + private final boolean use_space_char; + + public PpWordRecognitionTranslator(Map arguments) { + use_space_char = + arguments.containsKey("use_space_char") + ? 
Boolean.parseBoolean(arguments.get("use_space_char").toString()) + : true; + } + + @Override + public void prepare(TranslatorContext ctx) throws IOException { + Model model = ctx.getModel(); + try (InputStream is = model.getArtifact("dict.txt").openStream()) { + table = Utils.readLines(is, true); + table.add(0, "blank"); + if(use_space_char){ + table.add(" "); + table.add(" "); + } + else{ + table.add(""); + table.add(""); + } + + } + } + + @Override + public String processOutput(TranslatorContext ctx, NDList list) throws IOException { + StringBuilder sb = new StringBuilder(); + NDArray tokens = list.singletonOrThrow(); + + long[] indices = tokens.get(0).argMax(1).toLongArray(); + boolean[] selection = new boolean[indices.length]; + Arrays.fill(selection, true); + for (int i = 1; i < indices.length; i++) { + if (indices[i] == indices[i - 1]) { + selection[i] = false; + } + } + + // 字符置信度 +// float[] probs = new float[indices.length]; +// for (int row = 0; row < indices.length; row++) { +// NDArray value = tokens.get(0).get(new NDIndex(""+ row +":" + (row + 1) +"," + indices[row] +":" + ( indices[row] + 1))); +// probs[row] = value.toFloatArray()[0]; +// } + + int lastIdx = 0; + for (int i = 0; i < indices.length; i++) { + if (selection[i] == true && indices[i] > 0 && !(i > 0 && indices[i] == lastIdx)) { + sb.append(table.get((int) indices[i])); + } + } + return sb.toString(); + } + + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR); + int imgC = 3; + int imgH = 48; + int imgW = 320; + + float max_wh_ratio = (float) imgW / (float) imgH; + + int h = input.getHeight(); + int w = input.getWidth(); + float wh_ratio = (float) w / (float) h; + + max_wh_ratio = Math.max(max_wh_ratio,wh_ratio); + imgW = (int)(imgH * max_wh_ratio); + + int resized_w; + if (Math.ceil(imgH * wh_ratio) > imgW) { + resized_w = imgW; + } else { + resized_w = (int) (Math.ceil(imgH * wh_ratio)); + } + 
NDArray resized_image = NDImageUtils.resize(img, resized_w, imgH); + resized_image = resized_image.transpose(2, 0, 1).toType(DataType.FLOAT32,false); + resized_image.divi(255f).subi(0.5f).divi(0.5f); + NDArray padding_im = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW), DataType.FLOAT32); + padding_im.set(new NDIndex(":,:,0:" + resized_w), resized_image); + + padding_im = padding_im.flip(0); + padding_im = padding_im.expandDims(0); + return new NDList(padding_im); + } + + @Override + public Batchifier getBatchifier() { + return null; + } + +} diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/main/resources/log4j2.xml b/archive/1_image_sdks/ocr_iocr_sdk/src/main/resources/log4j2.xml new file mode 100755 index 00000000..4ec55f77 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/src/main/resources/log4j2.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + diff --git a/archive/1_image_sdks/ocr_iocr_sdk/src/test/resources/frame_1.jpg b/archive/1_image_sdks/ocr_iocr_sdk/src/test/resources/frame_1.jpg new file mode 100755 index 00000000..d831f645 Binary files /dev/null and b/archive/1_image_sdks/ocr_iocr_sdk/src/test/resources/frame_1.jpg differ diff --git a/archive/1_image_sdks/ocr_iocr_sdk/target/classes/log4j2.xml b/archive/1_image_sdks/ocr_iocr_sdk/target/classes/log4j2.xml new file mode 100755 index 00000000..4ec55f77 --- /dev/null +++ b/archive/1_image_sdks/ocr_iocr_sdk/target/classes/log4j2.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + diff --git a/archive/1_image_sdks/ocr_iocr_sdk/target/test-classes/frame_1.jpg b/archive/1_image_sdks/ocr_iocr_sdk/target/test-classes/frame_1.jpg new file mode 100755 index 00000000..d831f645 Binary files /dev/null and b/archive/1_image_sdks/ocr_iocr_sdk/target/test-classes/frame_1.jpg differ