mirror of
https://gitee.com/mymagicpower/AIAS.git
synced 2024-12-02 04:08:21 +08:00
update face feature similarity calculation.
This commit is contained in:
parent
132a1a3209
commit
8d30de3c53
@ -32,10 +32,12 @@ public final class FeatureComparisonExample {
|
||||
|
||||
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||
|
||||
Path imageFile1 = Paths.get("src/test/resources/kana1.jpg");
|
||||
Path imageFile1 = Paths.get("src/test/resources/kana1.png");
|
||||
Image img1 = ImageFactory.getInstance().fromFile(imageFile1);
|
||||
Path imageFile2 = Paths.get("src/test/resources/kana2.jpg");
|
||||
Path imageFile2 = Paths.get("src/test/resources/kana2.png");
|
||||
Image img2 = ImageFactory.getInstance().fromFile(imageFile2);
|
||||
Path imageFile3 = Paths.get("src/test/resources/beauty1.png");
|
||||
Image img3 = ImageFactory.getInstance().fromFile(imageFile3);
|
||||
|
||||
FaceFeature faceFeature = new FaceFeature();
|
||||
try (ZooModel<Image, float[]> model = ModelZoo.loadModel(faceFeature.criteria());
|
||||
@ -45,8 +47,12 @@ public final class FeatureComparisonExample {
|
||||
logger.info("face1 feature: " + Arrays.toString(feature1));
|
||||
float[] feature2 = predictor.predict(img2);
|
||||
logger.info("face2 feature: " + Arrays.toString(feature2));
|
||||
float[] feature3 = predictor.predict(img3);
|
||||
logger.info("face3 feature: " + Arrays.toString(feature3));
|
||||
|
||||
logger.info("kana1 - kana2 Similarity: "+ Float.toString(faceFeature.calculSimilar(feature1, feature2)));
|
||||
logger.info("kana1 - beauty1 Similarity: "+ Float.toString(faceFeature.calculSimilar(feature1, feature3)));
|
||||
|
||||
logger.info("相似度: "+ Float.toString(faceFeature.calculSimilar(feature1, feature2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ public final class FeatureExtractionExample {
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||
Path imageFile = Paths.get("src/test/resources/kana1.jpg");
|
||||
Path imageFile = Paths.get("src/test/resources/kana1.png");
|
||||
Image img = ImageFactory.getInstance().fromFile(imageFile);
|
||||
|
||||
FaceFeature faceFeature = new FaceFeature();
|
||||
|
@ -14,7 +14,7 @@ public final class FaceFeature {
|
||||
Criteria.builder()
|
||||
.optEngine("PaddlePaddle")
|
||||
.setTypes(Image.class, float[].class)
|
||||
.optModelUrls("https://aias-home.oss-cn-beijing.aliyuncs.com/models/sec_models/arcface_iresnet50_v1.0_infer.zip")
|
||||
.optModelUrls("https://aias-home.oss-cn-beijing.aliyuncs.com/models/sec_models/MobileFace.zip")
|
||||
// .optModelUrls("/Users/calvin/Downloads/models/arcface_iresnet50_v1.0_infer/")
|
||||
.optModelName("inference")
|
||||
.optTranslator(new FaceFeatureTranslator())
|
||||
@ -34,6 +34,6 @@ public final class FaceFeature {
|
||||
mod1 += feature1[i] * feature1[i];
|
||||
mod2 += feature2[i] * feature2[i];
|
||||
}
|
||||
return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
|
||||
return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2)));
|
||||
}
|
||||
}
|
||||
|
@ -15,58 +15,62 @@ import ai.djl.translate.TranslatorContext;
|
||||
|
||||
public final class FaceFeatureTranslator implements Translator<Image, float[]> {
|
||||
|
||||
public FaceFeatureTranslator() {}
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input){
|
||||
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
||||
|
||||
float percent = 128f / Math.min(input.getWidth(), input.getHeight());
|
||||
int resizedWidth = Math.round(input.getWidth() * percent);
|
||||
int resizedHeight = Math.round(input.getHeight() * percent);
|
||||
// img = img.resize((resizedWidth, resizedHeight), Image.LANCZOS)
|
||||
|
||||
array = NDImageUtils.resize(array,resizedWidth,resizedHeight);
|
||||
array = NDImageUtils.centerCrop(array,112,112);
|
||||
|
||||
// The network by default takes float32
|
||||
if (!array.getDataType().equals(DataType.FLOAT32)) {
|
||||
array = array.toType(DataType.FLOAT32, false);
|
||||
public FaceFeatureTranslator() {
|
||||
}
|
||||
|
||||
array = array.transpose(2, 0, 1).div(255f); // HWC -> CHW RGB
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
||||
|
||||
NDArray mean =
|
||||
ctx.getNDManager().create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
|
||||
NDArray std =
|
||||
ctx.getNDManager().create(new float[] {0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
|
||||
// float percent = 128f / Math.min(input.getWidth(), input.getHeight());
|
||||
// int resizedWidth = Math.round(input.getWidth() * percent);
|
||||
// int resizedHeight = Math.round(input.getHeight() * percent);
|
||||
// img = img.resize((resizedWidth, resizedHeight), Image.LANCZOS)
|
||||
|
||||
array = array.sub(mean);
|
||||
array = array.div(std);
|
||||
// array = NDImageUtils.resize(array,resizedWidth,resizedHeight);
|
||||
// array = NDImageUtils.centerCrop(array,112,112);
|
||||
array = NDImageUtils.resize(array, 112, 112);
|
||||
|
||||
// The network by default takes float32
|
||||
if (!array.getDataType().equals(DataType.FLOAT32)) {
|
||||
array = array.toType(DataType.FLOAT32, false);
|
||||
}
|
||||
|
||||
array = array.transpose(2, 0, 1).div(255f); // HWC -> CHW RGB
|
||||
|
||||
NDArray mean =
|
||||
ctx.getNDManager().create(new float[]{0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
|
||||
NDArray std =
|
||||
ctx.getNDManager().create(new float[]{0.5f, 0.5f, 0.5f}, new Shape(3, 1, 1));
|
||||
|
||||
array = array.sub(mean);
|
||||
array = array.div(std);
|
||||
|
||||
// array = array.expandDims(0);
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] processOutput(TranslatorContext ctx, NDList list) {
|
||||
NDList result = new NDList();
|
||||
long numOutputs = list.singletonOrThrow().getShape().get(0);
|
||||
for (int i = 0; i < numOutputs; i++) {
|
||||
result.add(list.singletonOrThrow().get(i));
|
||||
return new NDList(array);
|
||||
}
|
||||
float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
|
||||
float[] feature = new float[embeddings.length];
|
||||
for (int i = 0; i < embeddings.length; i++) {
|
||||
feature[i] = embeddings[i][0];
|
||||
}
|
||||
return feature;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return Batchifier.STACK;
|
||||
}
|
||||
@Override
|
||||
public float[] processOutput(TranslatorContext ctx, NDList list) {
|
||||
NDList result = new NDList();
|
||||
long numOutputs = list.singletonOrThrow().getShape().get(0);
|
||||
for (int i = 0; i < numOutputs; i++) {
|
||||
result.add(list.singletonOrThrow().get(i));
|
||||
}
|
||||
float[][] embeddings = result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
|
||||
float[] feature = new float[embeddings.length];
|
||||
for (int i = 0; i < embeddings.length; i++) {
|
||||
feature[i] = embeddings[i][0];
|
||||
}
|
||||
return feature;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return Batchifier.STACK;
|
||||
}
|
||||
}
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 71 KiB |
Binary file not shown.
After Width: | Height: | Size: 222 KiB |
Binary file not shown.
After Width: | Height: | Size: 395 KiB |
Loading…
Reference in New Issue
Block a user