package algorithms.clustering;

import Jama.Matrix;
import data.Datapoint;
import data.Dataset;
import java.util.Iterator;
import java.util.Random;

/* loaded from: input_file:algorithms/clustering/KMeans.class */
/**
 * K-means clustering of two-dimensional datapoints.
 *
 * <p>The clusterer keeps {@code k} centroids in a {@code k x 2} matrix (column 0 = x, column 1 =
 * y). A point's cluster is the centroid nearest to it, as measured by
 * {@link Datapoint#distanceFrom}. Cluster {@code c} is written to datapoints as label
 * {@code c + OFFSET}.
 *
 * <p>{@link #train} stores the dataset and performs a single assignment step only. Presumably
 * callers then alternate {@link #updateMeans()} and {@link #assignPoints()} until
 * {@link #converged()} is true; confirm against the caller.
 */
public class KMeans extends Clusterer {
    /** Added to a zero-based cluster index to form the label stored on datapoints. */
    public static final int OFFSET = 42;

    /** Centroid coordinates are (re)seeded within (-INIT_RANGE, INIT_RANGE). */
    private static final double INIT_RANGE = 10.0d;

    private final int k;
    /** Row {@code c} holds the (x, y) position of centroid {@code c}. */
    private final Matrix mu;
    /** Sum of point-to-nearest-centroid distances from the last {@link #assignPoints()} call. */
    private double totalDistance = 0.0d;
    // No explicit seed: seeding with currentTimeMillis() gave identical centroids to
    // instances created within the same millisecond.
    private final Random rand = new Random();
    private boolean converged = false;

    /**
     * Creates a clusterer with {@code k} randomly placed centroids.
     *
     * @param k number of clusters; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KMeans(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        this.k = k;
        this.mu = new Matrix(k, 2);
        for (int c = 0; c < k; c++) {
            this.mu.set(c, 0, randomCoordinate());
            this.mu.set(c, 1, randomCoordinate());
        }
    }

    /**
     * Returns the current centroids as a new dataset.
     *
     * <p>Centroid {@code c} is labelled {@code c + OFFSET}. The {@code dataset} argument is not
     * used.
     *
     * @param dataset ignored
     * @return a fresh dataset containing one point per centroid
     */
    @Override // algorithms.Algorithm
    public Dataset getOutput(Dataset dataset) {
        Dataset output = new Dataset();
        output.setStyle(4); // NOTE(review): style code 4 is defined by Dataset; confirm its meaning
        for (int c = 0; c < this.k; c++) {
            output.addDatapoint(this.mu.get(c, 0), this.mu.get(c, 1), c + OFFSET);
        }
        return output;
    }

    /**
     * Reports whether the most recent {@link #assignPoints()} left every label unchanged.
     *
     * @return {@code true} if no training point changed cluster in the last assignment step
     */
    public boolean converged() {
        return this.converged;
    }

    /**
     * Returns the label of the centroid nearest to {@code datapoint}.
     *
     * @param datapoint the point to classify
     * @return the nearest centroid's index plus {@link #OFFSET}; ties go to the lowest index
     */
    @Override // algorithms.Algorithm
    public double predict(Datapoint datapoint) {
        return nearestCentroid(datapoint, centroids()) + OFFSET;
    }

    /**
     * Stores {@code dataset} as the training data and runs one assignment step.
     *
     * <p>The datapoints' labels are overwritten.
     *
     * @param dataset the points to cluster
     */
    @Override // algorithms.Algorithm
    public void train(Dataset dataset) {
        this.trainData = dataset;
        assignPoints();
    }

    /**
     * Returns the summed distance of every training point to its assigned centroid.
     *
     * <p>The value is the one computed by the most recent {@link #assignPoints()} call.
     *
     * @return the summed distance, or 0 before any assignment step has run
     */
    public double getTotalDistance() {
        return this.totalDistance;
    }

    /**
     * Relabels every training point with its nearest centroid and recomputes the total distance.
     *
     * <p>Also updates {@link #converged()}.
     *
     * @return {@code true} if at least one point's cluster label changed
     */
    public boolean assignPoints() {
        // Centroids do not move during this pass, so build their Datapoints once.
        Datapoint[] centroids = centroids();
        boolean changed = false;
        this.totalDistance = 0.0d;
        Iterator<Datapoint> it = this.trainData.iterator();
        while (it.hasNext()) {
            Datapoint point = it.next();
            int cluster = nearestCentroid(point, centroids);
            this.totalDistance += point.distanceFrom(centroids[cluster]);
            if ((int) point.getLabel() - OFFSET != cluster) {
                changed = true;
            }
            point.setLabel(cluster + OFFSET);
        }
        this.converged = !changed;
        return changed;
    }

    /**
     * Moves each centroid to the mean of the training points currently labelled with it.
     *
     * <p>A centroid with no points is reseeded at a random position, drawn from the same range
     * the constructor uses. Every training label must already be a valid cluster label (index +
     * {@link #OFFSET}), e.g. after {@link #assignPoints()}. Otherwise an array index exception is
     * thrown.
     */
    public void updateMeans() {
        double[] sumX = new double[this.k];
        double[] sumY = new double[this.k];
        int[] count = new int[this.k];
        Iterator<Datapoint> it = this.trainData.iterator();
        while (it.hasNext()) {
            Datapoint point = it.next();
            int cluster = (int) point.getLabel() - OFFSET;
            sumX[cluster] += point.X();
            sumY[cluster] += point.Y();
            count[cluster]++;
        }
        for (int c = 0; c < this.k; c++) {
            if (count[c] > 0) {
                this.mu.set(c, 0, sumX[c] / count[c]);
                this.mu.set(c, 1, sumY[c] / count[c]);
            } else {
                // An empty cluster has no mean; reseed it like a freshly constructed centroid
                // rather than collapsing it towards [0, 1).
                this.mu.set(c, 0, randomCoordinate());
                this.mu.set(c, 1, randomCoordinate());
            }
        }
    }

    /** Returns a random coordinate in (-INIT_RANGE, INIT_RANGE) with a random sign. */
    private double randomCoordinate() {
        double magnitude = this.rand.nextDouble();
        return (this.rand.nextBoolean() ? -magnitude : magnitude) * INIT_RANGE;
    }

    /** Snapshots the current centroids as datapoints so distances can be measured against them. */
    private Datapoint[] centroids() {
        Datapoint[] points = new Datapoint[this.k];
        for (int c = 0; c < this.k; c++) {
            points[c] = new Datapoint(this.mu.get(c, 0), this.mu.get(c, 1), 1.0d);
        }
        return points;
    }

    /** Returns the index of the centroid closest to {@code p}; ties go to the lowest index. */
    private static int nearestCentroid(Datapoint p, Datapoint[] centroids) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double distance = p.distanceFrom(centroids[c]);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = c;
            }
        }
        return best;
    }
}
