package algorithms.clustering;

import Jama.Matrix;
import data.Datapoint;
import data.Dataset;
import functions.kernels.Kernel;
import functions.kernels.LinearKernel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Random;

/* loaded from: input_file:algorithms/clustering/KernelKMeans.class */
/**
 * Kernel k-means clustering over a precomputed kernel (Gram) matrix.
 *
 * <p>Cluster membership is stored directly on the training points: a point in cluster {@code c}
 * carries the label {@code OFFSET + c}. {@link #train(Dataset)} computes the kernel matrix and an
 * initial assignment; each call to {@link #assignPoints()} performs one batch reassignment step,
 * and can be repeated until it returns {@code false} (at which point {@link #converged()} is
 * {@code true}).
 *
 * <p>{@link #train(Dataset)} keeps a reference to the dataset it is given, and
 * {@link #assignPoints()} replaces that dataset's contents with relabelled copies of its points.
 */
public class KernelKMeans extends Clusterer {
    /** Offset added to a cluster index to form the label stored on a datapoint. */
    public static final int OFFSET = 42;

    private final Random rand;
    private final Kernel kernel;
    /** Gram matrix of the training points; row/column i is the i-th point in iteration order. */
    private Matrix kernelMatrix;
    private final int k;
    private boolean converged;
    /** Number of training points currently assigned to each cluster. */
    private int[] counts;

    /**
     * Creates a clusterer with an explicit random source (useful for reproducible runs).
     *
     * @param k number of clusters; must be at least 1
     * @param kernel kernel used to compare datapoints
     * @param rand random source used to pick the initial cluster seeds
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws NullPointerException if {@code kernel} or {@code rand} is null
     */
    public KernelKMeans(int k, Kernel kernel, Random rand) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        this.k = k;
        this.kernel = Objects.requireNonNull(kernel, "kernel");
        this.rand = Objects.requireNonNull(rand, "rand");
        this.converged = false;
    }

    /**
     * Creates a clusterer seeded from the current time.
     *
     * @param k number of clusters; must be at least 1
     * @param kernel kernel used to compare datapoints
     */
    public KernelKMeans(int k, Kernel kernel) {
        this(k, kernel, new Random(System.currentTimeMillis()));
    }

    /**
     * Creates a clusterer that uses a {@link LinearKernel}.
     *
     * @param k number of clusters; must be at least 1
     */
    public KernelKMeans(int k) {
        this(k, new LinearKernel());
    }

    /** Not implemented: always returns a new, empty {@link Dataset}. */
    @Override // algorithms.Algorithm
    public Dataset getOutput(Dataset dataset) {
        return new Dataset();
    }

    /** Not implemented: always returns {@code 0.0}. */
    @Override // algorithms.Algorithm
    public double predict(Datapoint datapoint) {
        return 0.0d;
    }

    /** Returns {@code true} if the last {@link #assignPoints()} call moved no point. */
    public boolean converged() {
        return this.converged;
    }

    /**
     * Computes the kernel matrix of {@code dataset} and labels every point with an initial cluster.
     *
     * <p>{@code k} distinct points are drawn at random and the c-th drawn point seeds cluster
     * {@code c}; every point that was not drawn is placed in the last cluster ({@code k - 1}).
     * {@link #assignPoints()} then refines this assignment.
     *
     * @param dataset training points; their labels are overwritten and the dataset is retained
     * @throws IllegalArgumentException if {@code dataset} has fewer than {@code k} points
     */
    @Override // algorithms.Algorithm
    public void train(Dataset dataset) {
        int n = dataset.size();
        if (n < k) {
            throw new IllegalArgumentException(
                    "cannot form " + k + " clusters from " + n + " datapoints");
        }
        this.trainData = dataset;
        this.converged = false; // a previous run's convergence says nothing about this dataset
        this.counts = new int[k];

        List<Datapoint> points = pointsOf(dataset);
        this.kernelMatrix = computeKernelMatrix(points);

        // Draw k distinct indices without replacement; the c-th draw seeds cluster c.
        List<Integer> unseeded = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            unseeded.add(i);
        }
        for (int c = 0; c < k; c++) {
            // remove(int) removes by position and returns the stored index
            int seed = unseeded.remove(rand.nextInt(unseeded.size()));
            points.get(seed).setLabel(OFFSET + c);
            counts[c]++;
        }
        // Every point that was not drawn starts in the last cluster.
        for (int idx : unseeded) {
            points.get(idx).setLabel(OFFSET + k - 1);
            counts[k - 1]++;
        }
    }

    /**
     * Performs one batch reassignment step.
     *
     * <p>For each point x and each non-empty cluster c, the squared feature-space distance from x
     * to the mean of c is
     * {@code K(x,x) - (2/|c|) * sum_{j in c} K(x,j) + (1/|c|^2) * sum_{i,j in c} K(i,j)},
     * evaluated against the assignment in effect before this call. Each point moves to the nearest
     * cluster (lowest index on ties). The training dataset's contents are then replaced by copies
     * of its points carrying the new labels.
     *
     * @return {@code true} if any point changed cluster, {@code false} once the assignment is stable
     * @throws IllegalStateException if {@link #train(Dataset)} has not been called
     */
    public boolean assignPoints() {
        if (kernelMatrix == null) {
            throw new IllegalStateException("train(Dataset) must be called before assignPoints()");
        }
        List<Datapoint> points = pointsOf(trainData);
        int n = points.size();

        // Read every point's cluster once instead of re-reading labels inside the inner loops.
        int[] cluster = new int[n];
        for (int i = 0; i < n; i++) {
            cluster[i] = (int) points.get(i).getLabel() - OFFSET;
        }

        double[] within = withinClusterSums(cluster);
        int[] newCounts = new int[k];
        Dataset reassigned = new Dataset();
        boolean changed = false;
        for (int i = 0; i < n; i++) {
            // toCluster[c] = sum of K(x_i, x_j) over j in cluster c, built in one pass over j.
            double[] toCluster = new double[k];
            for (int j = 0; j < n; j++) {
                toCluster[cluster[j]] += kernelMatrix.get(i, j);
            }
            int best = nearestCluster(kernelMatrix.get(i, i), toCluster, within);
            if (best < 0) {
                // No distance compared below +infinity (e.g. the kernel produced NaN): keep the
                // current assignment instead of indexing with -1.
                best = cluster[i];
            }
            newCounts[best]++;
            changed |= best != cluster[i];

            Datapoint copy = points.get(i).m0clone();
            copy.setLabel(best + OFFSET);
            reassigned.addDatapoint(copy);
        }
        trainData.clear();
        trainData.addAll(reassigned);
        counts = newCounts;
        converged = !changed;
        return changed;
    }

    /** Returns the points of {@code dataset} in iteration order (the kernel-matrix row order). */
    private static List<Datapoint> pointsOf(Dataset dataset) {
        List<Datapoint> points = new ArrayList<>(dataset.size());
        Iterator<Datapoint> it = dataset.iterator();
        while (it.hasNext()) {
            points.add(it.next());
        }
        return points;
    }

    /** Evaluates the kernel on every ordered pair of {@code points}. */
    private Matrix computeKernelMatrix(List<Datapoint> points) {
        int n = points.size();
        Matrix gram = new Matrix(n, n);
        for (int i = 0; i < n; i++) {
            Datapoint a = points.get(i);
            for (int j = 0; j < n; j++) {
                gram.set(i, j, kernel.evaluate(a, points.get(j)));
            }
        }
        return gram;
    }

    /** Returns, per cluster c, the sum of K(i, j) over all ordered pairs i, j in c. */
    private double[] withinClusterSums(int[] cluster) {
        int n = cluster.length;
        double[] within = new double[k];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (cluster[i] == cluster[j]) {
                    within[cluster[i]] += kernelMatrix.get(i, j);
                }
            }
        }
        return within;
    }

    /**
     * Returns the index of the non-empty cluster whose mean is nearest to a point, or -1 when no
     * distance compares below {@code +infinity}.
     *
     * @param selfSimilarity K(x, x) for the point
     * @param toCluster per-cluster sums of K(x, j) over members j
     * @param within per-cluster sums of K(i, j) over member pairs
     */
    private int nearestCluster(double selfSimilarity, double[] toCluster, double[] within) {
        int best = -1;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < k; c++) {
            int size = counts[c];
            if (size == 0) {
                continue; // an empty cluster has no mean
            }
            // Square in double: size * size in int overflows once size exceeds 46340.
            double distance = (selfSimilarity - (2.0d / size) * toCluster[c])
                    + within[c] / ((double) size * size);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = c;
            }
        }
        return best;
    }
}
