package algorithms.classification;

import data.Datapoint;
import data.Dataset;
import functions.kernels.Kernel;
import functions.kernels.LinearKernel;
import java.util.Iterator;

/**
 * Binary support vector machine trained with Sequential Minimal Optimization (SMO), using the
 * two-threshold bookkeeping of Keerthi et al. ("Improvements to Platt's SMO Algorithm for SVM
 * Classifier Design").
 *
 * <p>Training points labelled {@code 1.0} form the negative class (internal label -1) and
 * points labelled {@code 2.0} the positive class (+1). Points with any other label get internal
 * label 0; they belong to no index set, are never stepped, and so keep a multiplier of 0.
 *
 * <p>The decision value of a point {@code x} is {@code sum_i y_i * alpha_i * K(x, x_i) - b};
 * {@link #predictLabel} maps negative values to {@code 1.0} and all others to {@code 2.0}.
 *
 * <p>Index-set notation used by the solver (from the paper):
 * <ul>
 *   <li>I0: {@code 0 < alpha < C} (either class)</li>
 *   <li>I1: y = +1, alpha = 0; I2: y = -1, alpha = C</li>
 *   <li>I3: y = +1, alpha = C; I4: y = -1, alpha = 0</li>
 * </ul>
 * {@code bHigh} (the paper's b_up) tracks the minimum error F over I0, I1, I2 and {@code bLow}
 * (b_low) the maximum F over I0, I3, I4. An example counts as optimal unless it violates these
 * thresholds by more than {@code 2 * TOLERANCE}.
 */
public class SVM extends Classifier {
    /** {@link #getOutput} mode: label each point with its raw decision value. */
    public static final int RAW_OUTPUT = 0;
    /** {@link #getOutput} mode: label each point with its predicted class (1.0 or 2.0). */
    public static final int THRESHOLD_OUTPUT = 1;

    /** KKT tolerance ("tol" in the paper); a violation must exceed {@code 2 * TOLERANCE}. */
    private static final double TOLERANCE = 0.001d;
    /** Small constant for objective comparisons, the progress test and bound snapping. */
    private static final double EPSILON = 1.0E-12d;

    private int outputMode;
    private Kernel kernel;
    /** Box constraint: every multiplier satisfies {@code 0 <= alpha <= C}. */
    private double C;

    // Model / solver state. The arrays are indexed by training-point position and are
    // replaced together by a successful train(), so they always describe the same data.
    private Datapoint[] points;
    private int[] labels;     // -1, +1, or 0 for points of neither class
    private double[] alphas;  // Lagrange multipliers; null until the first successful train()
    private double[] errors;  // cache of F_i = sum_j alpha_j y_j K(x_i, x_j) - y_i; only
                              // entries for I0 and recently examined/stepped points are current
    private double b;         // final threshold: midpoint of bLow and bHigh
    private double bLow;
    private double bHigh;
    private int lowInd;       // index attaining bLow
    private int highInd;      // index attaining bHigh

    /**
     * Creates an SVM with the given kernel, {@code C = 1} and {@link #RAW_OUTPUT} mode.
     *
     * @param kernel kernel used for training and prediction
     */
    public SVM(Kernel kernel) {
        this.C = 1.0d;
        this.kernel = kernel;
        this.outputMode = RAW_OUTPUT;
    }

    /** Creates an SVM with a {@link LinearKernel}, {@code C = 1} and {@link #RAW_OUTPUT} mode. */
    public SVM() {
        this(new LinearKernel());
    }

    /**
     * Selects how {@link #getOutput} labels points. Values other than {@link #RAW_OUTPUT} and
     * {@link #THRESHOLD_OUTPUT} are ignored and leave the current mode unchanged.
     *
     * @param i the new output mode
     */
    public void setOutputMode(int i) {
        if (i == RAW_OUTPUT || i == THRESHOLD_OUTPUT) {
            this.outputMode = i;
        }
    }

    /**
     * Sets the box constraint C used by the next call to {@link #train}.
     *
     * @param d the new value of C
     */
    public void setC(double d) {
        this.C = d;
    }

    /**
     * Replaces the kernel. Note that {@link #predict} always evaluates the current kernel, so
     * call {@link #train} again after changing it to keep the model consistent.
     *
     * @param kernel the new kernel
     */
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
    }

    /**
     * Returns the decision value for a point: negative on the class-1.0 side, non-negative on
     * the class-2.0 side.
     *
     * @param datapoint the point to evaluate
     * @return {@code sum_i y_i * alpha_i * K(datapoint, x_i) - b}
     * @throws IllegalStateException if the SVM has never been trained successfully
     */
    @Override // algorithms.Algorithm
    public double predict(Datapoint datapoint) {
        if (alphas == null) {
            throw new IllegalStateException("SVM has not been trained");
        }
        double sum = 0.0d;
        for (int i = 0; i < alphas.length; i++) {
            if (alphas[i] > 0.0d) {  // only support vectors contribute
                sum += labels[i] * alphas[i] * kernel.evaluate(datapoint, points[i]);
            }
        }
        return sum - b;
    }

    /**
     * Returns {@code 1.0} when the decision value is negative and {@code 2.0} otherwise.
     *
     * @param datapoint the point to classify
     * @return the predicted class label
     */
    @Override // algorithms.classification.Classifier
    public double predictLabel(Datapoint datapoint) {
        return predict(datapoint) < 0.0d ? 1.0d : 2.0d;
    }

    /**
     * Relabels every point of {@code dataset} in place according to the output mode and
     * returns the same instance. In {@link #RAW_OUTPUT} mode the dataset's min/max label range
     * is also widened to cover the written decision values.
     *
     * @param dataset points to label
     * @return {@code dataset}
     */
    @Override // algorithms.Algorithm
    public Dataset getOutput(Dataset dataset) {
        Iterator<Datapoint> it = dataset.iterator();
        while (it.hasNext()) {
            Datapoint next = it.next();
            double value = predict(next);
            if (outputMode == THRESHOLD_OUTPUT) {
                // NOTE(review): style codes are defined by Dataset; 2 is used for class labels here.
                dataset.setStyle(2);
                next.setLabel(value < 0.0d ? 1.0d : 2.0d);
            } else {
                // NOTE(review): presumably 3 selects a continuous-valued label style — confirm.
                dataset.setStyle(3);
                next.setLabel(value);
                if (value < dataset.getMinLabel()) {
                    dataset.setMinLabel(value);
                }
                if (value > dataset.getMaxLabel()) {
                    dataset.setMaxLabel(value);
                }
            }
        }
        return dataset;
    }

    /**
     * Trains the SVM on {@code dataset}.
     *
     * <p>Training is skipped, and {@code setVisible(false)} called, when the dataset has fewer
     * than 3 points or lacks a point of either class; in that case the previously trained
     * model, if any, is left untouched. On success every training point is highlighted iff it
     * is a support vector ({@code alpha > 0}).
     *
     * @param dataset training points labelled 1.0 / 2.0
     */
    @Override // algorithms.Algorithm
    public void train(Dataset dataset) {
        int n = dataset.size();
        if (n < 3) {
            setVisible(false);
            return;
        }
        setVisible(true);

        // Build into locals so that bailing out below cannot leave points/labels out of sync
        // with the multipliers of a previously trained model.
        Datapoint[] newPoints = new Datapoint[n];
        int[] newLabels = new int[n];
        int lastNegative = -1;
        int lastPositive = -1;
        int idx = 0;
        Iterator<Datapoint> it = dataset.iterator();
        while (it.hasNext()) {
            Datapoint p = it.next();
            newPoints[idx] = p;
            if (p.getLabel() == 1.0d) {
                newLabels[idx] = -1;
                lastNegative = idx;
            } else if (p.getLabel() == 2.0d) {
                newLabels[idx] = 1;
                lastPositive = idx;
            }
            idx++;
        }
        if (lastNegative == -1 || lastPositive == -1) {
            setVisible(false);
            return;
        }

        points = newPoints;
        labels = newLabels;
        alphas = new double[n];
        errors = new double[n];
        // With every alpha at 0, F_i = -y_i. Seed the thresholds with one point of each class:
        // b_low = +1 from a negative point, b_up = -1 from a positive point.
        lowInd = lastNegative;
        highInd = lastPositive;
        bLow = 1.0d;
        bHigh = -1.0d;
        errors[lowInd] = 1.0d;
        errors[highInd] = -1.0d;

        optimize();

        for (int k = 0; k < n; k++) {
            points[k].setHighlighted(alphas[k] > 0.0d);
        }
        b = (bLow + bHigh) / 2.0d;
    }

    /**
     * SMO outer loop: alternates a sweep over all examples with sweeps over I0 only, until a
     * full sweep changes nothing.
     */
    private void optimize() {
        boolean examineAll = true;
        int numChanged = 0;
        while (numChanged > 0 || examineAll) {
            numChanged = 0;
            if (examineAll) {
                for (int k = 0; k < alphas.length; k++) {
                    if (examineExample(k)) {
                        numChanged++;
                    }
                }
            } else {
                for (int k = 0; k < alphas.length; k++) {
                    if (!I0(k)) {
                        continue;
                    }
                    if (examineExample(k)) {
                        numChanged++;
                    }
                    // Thresholds agree within 2*tol: I0 is optimal, fall back to a full sweep.
                    if (bHigh > bLow - 2 * TOLERANCE) {
                        numChanged = 0;
                        break;
                    }
                }
            }
            if (examineAll) {
                examineAll = false;
            } else if (numChanged == 0) {
                examineAll = true;
            }
        }
    }

    /** I0: multiplier strictly inside the box (either class). */
    private boolean I0(int i) {
        return alphas[i] > 0.0d && alphas[i] < C;
    }

    /** I1: positive example at the lower bound. */
    private boolean I1(int i) {
        return labels[i] == 1 && alphas[i] == 0.0d;
    }

    /** I2: negative example at the upper bound. */
    private boolean I2(int i) {
        return labels[i] == -1 && alphas[i] == C;
    }

    /** I3: positive example at the upper bound. */
    private boolean I3(int i) {
        return labels[i] == 1 && alphas[i] == C;
    }

    /** I4: negative example at the lower bound. */
    private boolean I4(int i) {
        return labels[i] == -1 && alphas[i] == 0.0d;
    }

    /**
     * Checks example {@code i2} against the thresholds and, if it violates them by more than
     * {@code 2 * TOLERANCE}, attempts a joint step with the partner held in lowInd/highInd.
     *
     * @param i2 index of the example to examine
     * @return true if a step changed the multipliers
     */
    private boolean examineExample(int i2) {
        int y2 = labels[i2];
        double f2;
        if (I0(i2)) {
            f2 = errors[i2];  // takeStep keeps the cache current for I0 members
        } else {
            double sum = 0.0d;
            for (int j = 0; j < points.length; j++) {
                if (alphas[j] > 0.0d) {
                    sum += alphas[j] * labels[j] * kernel.evaluate(points[i2], points[j]);
                }
            }
            f2 = sum - y2;
            errors[i2] = f2;
            // A non-I0 example only bounds the threshold its index set defines.
            if ((I1(i2) || I2(i2)) && f2 < bHigh) {
                bHigh = f2;
                highInd = i2;
            } else if ((I3(i2) || I4(i2)) && f2 > bLow) {
                bLow = f2;
                lowInd = i2;
            }
        }

        boolean optimal = true;
        int i1 = -1;
        if ((I0(i2) || I1(i2) || I2(i2)) && bLow - f2 > 2 * TOLERANCE) {
            optimal = false;
            i1 = lowInd;
        }
        if ((I0(i2) || I3(i2) || I4(i2)) && f2 - bHigh > 2 * TOLERANCE) {
            optimal = false;
            i1 = highInd;
        }
        if (optimal) {
            return false;
        }
        if (I0(i2)) {
            // Pair with whichever threshold is violated more.
            i1 = bLow - f2 > f2 - bHigh ? lowInd : highInd;
        }
        if (i1 < 0) {
            return false;  // defensive: no partner available
        }
        return takeStep(i1, i2);
    }

    /**
     * Jointly optimizes the multipliers of {@code i1} and {@code i2} (Platt's takeStep),
     * refreshes the error cache and recomputes the thresholds.
     *
     * @param i1 first index of the pair
     * @param i2 second index of the pair
     * @return true if the multipliers changed by a non-negligible amount
     */
    private boolean takeStep(int i1, int i2) {
        if (i1 == i2) {
            return false;
        }
        double alph1 = alphas[i1];
        double alph2 = alphas[i2];
        int y1 = labels[i1];
        int y2 = labels[i2];
        double f1 = errors[i1];
        double f2 = errors[i2];
        int s = y1 * y2;

        // Feasible segment [lo, hi] for the new alpha2 under the box and equality constraints.
        double lo;
        double hi;
        if (y1 != y2) {
            lo = Math.max(0.0d, alph2 - alph1);
            hi = Math.min(C, (C + alph2) - alph1);
        } else {
            lo = Math.max(0.0d, (alph1 + alph2) - C);
            hi = Math.min(C, alph1 + alph2);
        }
        if (lo >= hi) {
            return false;
        }

        double k11 = kernel.evaluate(points[i1], points[i1]);
        double k12 = kernel.evaluate(points[i1], points[i2]);
        double k22 = kernel.evaluate(points[i2], points[i2]);
        double eta = ((2.0d * k12) - k11) - k22;

        double a2;
        if (eta < 0.0d) {
            // Optimum along the segment direction, clipped to [lo, hi].
            a2 = alph2 - ((y2 * (f1 - f2)) / eta);
            if (a2 < lo) {
                a2 = lo;
            } else if (a2 > hi) {
                a2 = hi;
            }
        } else {
            // eta >= 0: the closed-form step above is not usable, so compare the objective
            // at both ends of the segment and move only if one end is clearly better.
            double g1 = ((y1 * f1) - (alph1 * k11)) - ((s * alph2) * k12);
            double g2 = ((y2 * f2) - ((s * alph1) * k12)) - (alph2 * k22);
            double lo1 = alph1 + (s * (alph2 - lo));
            double hi1 = alph1 + (s * (alph2 - hi));
            double loObj = (lo1 * g1) + (lo * g2) + (0.5d * lo1 * lo1 * k11)
                    + (0.5d * lo * lo * k22) + (s * lo * lo1 * k12);
            double hiObj = (hi1 * g1) + (hi * g2) + (0.5d * hi1 * hi1 * k11)
                    + (0.5d * hi * hi * k22) + (s * hi * hi1 * k12);
            if (loObj < hiObj - EPSILON) {
                a2 = lo;
            } else if (loObj > hiObj + EPSILON) {
                a2 = hi;
            } else {
                a2 = alph2;
            }
        }

        // Snap onto exact bounds so each multiplier lands in exactly one index set; without
        // this, rounding can leave alpha slightly outside [0, C] and in none of I0..I4.
        a2 = snapToBounds(a2);
        if (Math.abs(a2 - alph2) < EPSILON * (a2 + alph2 + EPSILON)) {
            return false;
        }
        double a1 = snapToBounds(alph1 + (s * (alph2 - a2)));

        double delta1 = y1 * (a1 - alph1);
        double delta2 = y2 * (a2 - alph2);

        // Refresh cached errors of the other I0 members; i1 and i2 are updated below.
        for (int k = 0; k < alphas.length; k++) {
            if (I0(k) && k != i1 && k != i2) {
                errors[k] = errors[k]
                        + (delta1 * kernel.evaluate(points[i1], points[k]))
                        + (delta2 * kernel.evaluate(points[i2], points[k]));
            }
        }
        errors[i1] = errors[i1] + (delta1 * k11) + (delta2 * k12);
        errors[i2] = errors[i2] + (delta1 * k12) + (delta2 * k22);

        alphas[i1] = a1;
        alphas[i2] = a2;

        updateThresholds(i1, i2);
        return true;
    }

    /**
     * Recomputes (bLow, lowInd) and (bHigh, highInd) from I0 plus the just-updated pair. Each
     * index feeds only the threshold(s) its index set defines (I0/I1/I2 for bHigh, I0/I3/I4
     * for bLow); letting i1 or i2 feed both after leaving I0 could push a threshold past its
     * true value.
     *
     * @param i1 first index of the pair just stepped
     * @param i2 second index of the pair just stepped
     */
    private void updateThresholds(int i1, int i2) {
        bLow = -Double.MAX_VALUE;
        bHigh = Double.MAX_VALUE;
        lowInd = -1;
        highInd = -1;
        for (int k = 0; k < alphas.length; k++) {
            if (!(I0(k) || k == i1 || k == i2)) {
                continue;
            }
            if ((I0(k) || I1(k) || I2(k)) && errors[k] < bHigh) {
                bHigh = errors[k];
                highInd = k;
            }
            if ((I0(k) || I3(k) || I4(k)) && errors[k] > bLow) {
                bLow = errors[k];
                lowInd = k;
            }
        }
    }

    /**
     * Maps a multiplier that lies within {@code EPSILON * C} of a bound, or beyond it, onto
     * exactly 0 or C; other values are returned unchanged.
     *
     * @param alpha candidate multiplier
     * @return the snapped multiplier
     */
    private double snapToBounds(double alpha) {
        double slack = EPSILON * C;
        if (alpha <= slack) {
            return 0.0d;
        }
        if (alpha >= C - slack) {
            return C;
        }
        return alpha;
    }
}
