package ai.fritz.fritzvisionsegmentation;

import ai.fritz.fritzvisionsegmentation.FritzVisionSegmentPredictorOptions;
import ai.fritz.vision.inputs.FritzVisionImage;
import ai.fritz.vision.inputs.PreparedImage;
import ai.fritz.vision.predictors.FritzVisionPredictor;
import android.graphics.Bitmap;
import android.graphics.Point;
import android.util.Log;
import android.util.Size;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: classes.dex */
/**
 * Runs a segmentation model on a {@link FritzVisionImage} and groups every output cell into one
 * {@link FritzVisionMask} per {@link MaskType}.
 *
 * <p>Instances reuse their input/output buffers on every {@link #predict} call, so a single
 * instance must not be used from several threads at once.
 */
public class FritzVisionSegmentPredictor extends FritzVisionPredictor<FritzVisionSegmentResult> {
    /** Number of color channels fed to the model (R, G, B). */
    private static final int NUM_CHANNELS = 3;
    private static final String TAG = "FritzVisionSegmentPredictor";
    private final FloatBuffer inputBuffer;
    private final String inputLayerName;
    private final int inputSize;
    private final int[] intValues;
    private FritzVisionSegmentPredictorOptions options;
    private final FloatBuffer outputBuffer;
    private final String outputLayerName;
    private final int outputSize;
    /** Private, never-modified copy of the model's class labels; target filtering is always derived from it. */
    private final MaskType[] modelClassifications;
    /** Model class labels with every non-target class replaced by {@link MaskType#NONE}. */
    private MaskType[] segmentClassifications;

    /** Creates a predictor for {@code segmentModel} using default options. */
    public FritzVisionSegmentPredictor(SegmentModel segmentModel) {
        this(segmentModel, new FritzVisionSegmentPredictorOptions.Builder().build());
    }

    /**
     * Creates a predictor for {@code segmentModel}.
     *
     * @param segmentModel model supplying sizes, layer names and class labels
     * @param fritzVisionSegmentPredictorOptions options controlling target segments, confidence
     *     threshold and crop/scale behavior
     */
    public FritzVisionSegmentPredictor(SegmentModel segmentModel, FritzVisionSegmentPredictorOptions fritzVisionSegmentPredictorOptions) {
        super(segmentModel);
        this.inputSize = segmentModel.getInputSize();
        this.outputSize = segmentModel.getOutputSize();
        this.inputLayerName = segmentModel.getInputLayerName();
        this.outputLayerName = segmentModel.getOutputLayerName();
        // Copy so that target filtering never writes into the array handed out by the model.
        this.modelClassifications = segmentModel.getClassifications().clone();
        this.segmentClassifications = setTargetClassifications(this.modelClassifications, fritzVisionSegmentPredictorOptions.getTargetSegments());
        this.intValues = new int[this.inputSize * this.inputSize];
        this.inputBuffer = FloatBuffer.allocate(this.inputSize * this.inputSize * NUM_CHANNELS);
        this.outputBuffer = FloatBuffer.allocate(this.outputSize * this.outputSize * this.segmentClassifications.length);
        this.options = fritzVisionSegmentPredictorOptions;
    }

    /**
     * Maps each mask point from output-grid coordinates into the prepared image's coordinates.
     *
     * @param masks masks whose points are updated in place
     * @param inferenceSize target inference size reported by the prepared image
     * @param offsetX horizontal offset reported by the prepared image
     * @param offsetY vertical offset reported by the prepared image
     */
    private void calculateSegment(List<FritzVisionMask> masks, Size inferenceSize, int offsetX, int offsetY) {
        // NOTE(review): integer division truncates the scale factor when the inference size is not
        // a multiple of outputSize; confirm calculateScaledPixel's parameter types before changing.
        int scaleX = inferenceSize.getWidth() / this.outputSize;
        int scaleY = inferenceSize.getHeight() / this.outputSize;
        for (FritzVisionMask mask : masks) {
            for (FritzVisionPoint point : mask.getMaskPoints()) {
                point.setSegmentBox(point.calculateScaledPixel(scaleX, scaleY, offsetX, offsetY));
            }
        }
    }

    /** Returns the offset of output column {@code i} within one row of the [row][col][class] output. */
    private int getColOffset(int i) {
        return i * this.segmentClassifications.length;
    }

    /** Returns the offset of output row {@code i} within the [row][col][class] output. */
    private int getRowOffset(int i) {
        return i * this.outputSize * this.segmentClassifications.length;
    }

    /**
     * Assigns a {@link MaskType} to every output cell and groups the cells into masks.
     *
     * <p>If the confidence threshold is positive, a cell gets the best-scoring non-NONE class when
     * that score reaches the threshold, otherwise {@link MaskType#NONE}. If the threshold is not
     * positive, a cell gets the best-scoring class overall. Each point records the highest score
     * over all classes.
     *
     * @return one mask per MaskType that occurs in the output
     */
    private List<FritzVisionMask> postprocess() {
        Map<MaskType, FritzVisionMask> masksByType = new HashMap<>();
        float threshold = this.options.getTargetConfidenceThreshold();
        for (int row = 0; row < this.outputSize; row++) {
            for (int col = 0; col < this.outputSize; col++) {
                // NOTE(review): both maxima are seeded with 0.0f, so if every score is negative
                // (e.g. raw logits) the argmax falls back to index 0 — confirm the model emits
                // non-negative scores.
                float bestTargetScore = 0.0f;
                int bestTargetIndex = 0;
                float bestScore = 0.0f;
                int bestIndex = 0;
                for (int c = 0; c < this.segmentClassifications.length; c++) {
                    float score = this.outputBuffer.get(getRowOffset(row) + getColOffset(col) + c);
                    if (!MaskType.NONE.equals(this.segmentClassifications[c]) && score > bestTargetScore) {
                        bestTargetIndex = c;
                        bestTargetScore = score;
                    }
                    if (score > bestScore) {
                        bestIndex = c;
                        bestScore = score;
                    }
                }
                MaskType maskType;
                if (threshold > 0.0f) {
                    maskType = bestTargetScore >= threshold ? this.segmentClassifications[bestTargetIndex] : MaskType.NONE;
                } else {
                    maskType = this.segmentClassifications[bestIndex];
                }
                FritzVisionPoint point = new FritzVisionPoint(new Point(col, row), bestScore);
                FritzVisionMask mask = masksByType.get(maskType);
                if (mask == null) {
                    mask = new FritzVisionMask(maskType);
                    masksByType.put(maskType, mask);
                }
                mask.addPoint(point);
            }
        }
        return new ArrayList<>(masksByType.values());
    }

    /**
     * Writes the bitmap's pixels into {@link #inputBuffer} as RGB floats in [-0.5, 0.5].
     *
     * <p>The bitmap is expected to be {@code inputSize x inputSize} — TODO confirm for every
     * crop/scale option; the indexing below assumes a row stride of {@code inputSize}.
     */
    private void preprocess(Bitmap bitmap) {
        bitmap.getPixels(this.intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        this.inputBuffer.rewind();
        for (int row = 0; row < this.inputSize; row++) {
            for (int col = 0; col < this.inputSize; col++) {
                int pixelIndex = (this.inputSize * row) + col;
                int pixel = this.intValues[pixelIndex];
                int base = pixelIndex * NUM_CHANNELS;
                this.inputBuffer.put(base, normalizeChannel(pixel >> 16));
                this.inputBuffer.put(base + 1, normalizeChannel(pixel >> 8));
                this.inputBuffer.put(base + 2, normalizeChannel(pixel));
            }
        }
    }

    /** Scales the low 8 bits of {@code value} from [0, 255] to [-0.5, 0.5]. */
    private static float normalizeChannel(int value) {
        return ((value & 255) / 255.0f) - 0.5f;
    }

    /**
     * Returns a copy of {@code maskTypeArr} in which every class not in {@code list} is replaced by
     * {@link MaskType#NONE}. The input array is never modified.
     *
     * @param maskTypeArr full class list of the model
     * @param list classes to keep, or {@code null} to keep all of them
     * @return a new, filtered array of the same length
     */
    private MaskType[] setTargetClassifications(MaskType[] maskTypeArr, List<MaskType> list) {
        MaskType[] filtered = maskTypeArr.clone();
        if (list == null) {
            return filtered;
        }
        for (int i = 0; i < filtered.length; i++) {
            if (!list.contains(filtered[i])) {
                filtered[i] = MaskType.NONE;
            }
        }
        return filtered;
    }

    /**
     * Runs the model on {@code fritzVisionImage} and returns the resulting masks.
     *
     * @param fritzVisionImage image to segment
     * @return masks grouped by MaskType, with points mapped into the prepared image's coordinates
     */
    @Override // ai.fritz.vision.predictors.FritzVisionPredictor
    public FritzVisionSegmentResult predict(FritzVisionImage fritzVisionImage) {
        PreparedImage create = PreparedImage.create(fritzVisionImage, this.options.getCropAndScaleOption(), new Size(this.inputSize, this.inputSize));
        this.outputBuffer.rewind();
        String[] strArr = {this.outputLayerName};
        preprocess(create.getBitmapForModel());
        long currentTimeMillis = System.currentTimeMillis();
        this.interpreter.feed(this.inputLayerName, this.inputBuffer, 1, this.inputSize, this.inputSize, NUM_CHANNELS);
        this.interpreter.run(strArr);
        this.interpreter.fetch(this.outputLayerName, this.outputBuffer);
        Log.d(TAG, "INFERENCE TOOK: " + (System.currentTimeMillis() - currentTimeMillis) + "ms");
        List<FritzVisionMask> postprocess = postprocess();
        calculateSegment(postprocess, create.getTargetInferenceSize(), create.getOffsetX(), create.getOffsetY());
        return new FritzVisionSegmentResult(fritzVisionImage, postprocess);
    }

    /**
     * Replaces the predictor options.
     *
     * <p>Target segments are re-derived from the model's full class list, so a later call can widen
     * the target set as well as narrow it.
     */
    public void setOptions(FritzVisionSegmentPredictorOptions fritzVisionSegmentPredictorOptions) {
        this.options = fritzVisionSegmentPredictorOptions;
        this.segmentClassifications = setTargetClassifications(this.modelClassifications, fritzVisionSegmentPredictorOptions.getTargetSegments());
    }
}
