update face feature similarity calculation.

This commit is contained in:
Calvin 2022-10-17 11:27:14 +08:00
parent 132a1a3209
commit 8d30de3c53
7 changed files with 61 additions and 51 deletions

View File

@ -32,10 +32,12 @@ public final class FeatureComparisonExample {
public static void main(String[] args) throws IOException, ModelException, TranslateException {
Path imageFile1 = Paths.get("src/test/resources/kana1.jpg");
Path imageFile1 = Paths.get("src/test/resources/kana1.png");
Image img1 = ImageFactory.getInstance().fromFile(imageFile1);
Path imageFile2 = Paths.get("src/test/resources/kana2.jpg");
Path imageFile2 = Paths.get("src/test/resources/kana2.png");
Image img2 = ImageFactory.getInstance().fromFile(imageFile2);
Path imageFile3 = Paths.get("src/test/resources/beauty1.png");
Image img3 = ImageFactory.getInstance().fromFile(imageFile3);
FaceFeature faceFeature = new FaceFeature();
try (ZooModel<Image, float[]> model = ModelZoo.loadModel(faceFeature.criteria());
@ -45,8 +47,12 @@ public final class FeatureComparisonExample {
logger.info("face1 feature: " + Arrays.toString(feature1));
float[] feature2 = predictor.predict(img2);
logger.info("face2 feature: " + Arrays.toString(feature2));
float[] feature3 = predictor.predict(img3);
logger.info("face3 feature: " + Arrays.toString(feature3));
logger.info("kana1 - kana2 Similarity "+ Float.toString(faceFeature.calculSimilar(feature1, feature2)));
logger.info("kana1 - beauty1 Similarity "+ Float.toString(faceFeature.calculSimilar(feature1, feature3)));
logger.info("相似度: "+ Float.toString(faceFeature.calculSimilar(feature1, feature2)));
}
}
}

View File

@ -31,7 +31,7 @@ public final class FeatureExtractionExample {
}
public static void main(String[] args) throws IOException, ModelException, TranslateException {
Path imageFile = Paths.get("src/test/resources/kana1.jpg");
Path imageFile = Paths.get("src/test/resources/kana1.png");
Image img = ImageFactory.getInstance().fromFile(imageFile);
FaceFeature faceFeature = new FaceFeature();

View File

@ -14,7 +14,7 @@ public final class FaceFeature {
Criteria.builder()
.optEngine("PaddlePaddle")
.setTypes(Image.class, float[].class)
.optModelUrls("https://aias-home.oss-cn-beijing.aliyuncs.com/models/sec_models/arcface_iresnet50_v1.0_infer.zip")
.optModelUrls("https://aias-home.oss-cn-beijing.aliyuncs.com/models/sec_models/MobileFace.zip")
// .optModelUrls("/Users/calvin/Downloads/models/arcface_iresnet50_v1.0_infer/")
.optModelName("inference")
.optTranslator(new FaceFeatureTranslator())
@ -34,6 +34,6 @@ public final class FaceFeature {
mod1 += feature1[i] * feature1[i];
mod2 += feature2[i] * feature2[i];
}
return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2)));
}
}

View File

@ -15,58 +15,62 @@ import ai.djl.translate.TranslatorContext;
public final class FaceFeatureTranslator implements Translator<Image, float[]> {
public FaceFeatureTranslator() {}
@Override
public NDList processInput(TranslatorContext ctx, Image input){
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
float percent = 128f / Math.min(input.getWidth(), input.getHeight());
int resizedWidth = Math.round(input.getWidth() * percent);
int resizedHeight = Math.round(input.getHeight() * percent);
// img = img.resize((resizedWidth, resizedHeight), Image.LANCZOS)
array = NDImageUtils.resize(array,resizedWidth,resizedHeight);
array = NDImageUtils.centerCrop(array,112,112);
// The network by default takes float32
if (!array.getDataType().equals(DataType.FLOAT32)) {
array = array.toType(DataType.FLOAT32, false);
public FaceFeatureTranslator() {
}
array = array.transpose(2, 0, 1).div(255f); // HWC -> CHW RGB
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
NDArray mean =
ctx.getNDManager().create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
NDArray std =
ctx.getNDManager().create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
// float percent = 128f / Math.min(input.getWidth(), input.getHeight());
// int resizedWidth = Math.round(input.getWidth() * percent);
// int resizedHeight = Math.round(input.getHeight() * percent);
// img = img.resize((resizedWidth, resizedHeight), Image.LANCZOS)
array = array.sub(mean);
array = array.div(std);
// array = NDImageUtils.resize(array,resizedWidth,resizedHeight);
// array = NDImageUtils.centerCrop(array,112,112);
array = NDImageUtils.resize(array, 112, 112);
// The network by default takes float32
if (!array.getDataType().equals(DataType.FLOAT32)) {
array = array.toType(DataType.FLOAT32, false);
}
array = array.transpose(2, 0, 1).div(255f); // HWC -> CHW RGB
NDArray mean =
ctx.getNDManager().create(new float[]{0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
NDArray std =
ctx.getNDManager().create(new float[]{0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
array = array.sub(mean);
array = array.div(std);
// array = array.expandDims(0);
return new NDList(array);
}
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDList result = new NDList();
long numOutputs = list.singletonOrThrow().getShape().get(0);
for (int i = 0; i < numOutputs; i++) {
result.add(list.singletonOrThrow().get(i));
return new NDList(array);
}
float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
float[] feature = new float[embeddings.length];
for (int i = 0; i < embeddings.length; i++) {
feature[i] = embeddings[i][0];
}
return feature;
}
/** {@inheritDoc} */
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDList result = new NDList();
long numOutputs = list.singletonOrThrow().getShape().get(0);
for (int i = 0; i < numOutputs; i++) {
result.add(list.singletonOrThrow().get(i));
}
float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
float[] feature = new float[embeddings.length];
for (int i = 0; i < embeddings.length; i++) {
feature[i] = embeddings[i][0];
}
return feature;
}
/**
* {@inheritDoc}
*/
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 222 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 395 KiB