package algorithms.classification;

import data.Datapoint;
import data.Dataset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:algorithms/classification/KNN.class */
/**
 * k-nearest-neighbours classifier.
 *
 * <p>Training only stores the supplied {@link Dataset}; a prediction asks that dataset for
 * the {@code k} points nearest to the query and returns the label that occurs most often
 * among them. Labels are read as {@code double} values and truncated to {@code int}; they
 * must be {@code >= 1}.
 */
public class KNN extends Classifier {
    /**
     * Initial number of vote slots. Kept at 5 so that a query with no neighbours still
     * yields a random label in 1..5, as before; the vote array grows when a larger label
     * is encountered.
     */
    private static final int INITIAL_LABEL_SLOTS = 5;

    /** Number of neighbours consulted for each prediction. */
    private int k;

    /**
     * Creates an untrained classifier.
     *
     * @param i number of neighbours ({@code k}) to consult per prediction
     */
    public KNN(int i) {
        this.k = i;
        setVisible(false);
        this.trainData = null;
    }

    /**
     * Predicts the label of {@code datapoint} by majority vote among its {@code k} nearest
     * training points. Ties between the most frequent labels are broken uniformly at random.
     *
     * @param datapoint the point to classify
     * @return the winning label (an integer value {@code >= 1}, returned as a double)
     * @throws NullPointerException if {@link #train(Dataset)} has not been called yet
     * @throws ArrayIndexOutOfBoundsException if a neighbour carries a label below 1
     */
    @Override // algorithms.Algorithm
    public double predict(Datapoint datapoint) {
        // Same exception type as the implicit dereference failure, but with a usable message.
        Objects.requireNonNull(this.trainData, "KNN.predict called before train(Dataset)");

        Iterator<Datapoint> neighbours = this.trainData.kNearest(datapoint, this.k).iterator();
        int[] votes = new int[INITIAL_LABEL_SLOTS];
        while (neighbours.hasNext()) {
            int label = (int) neighbours.next().getLabel();
            if (label < 1) {
                // Keep the exception type a negative array index would raise, but say why.
                throw new ArrayIndexOutOfBoundsException(
                        "KNN expects labels >= 1 but found " + label);
            }
            if (label > votes.length) {
                // Grow instead of failing on labels beyond the initial slot count.
                votes = Arrays.copyOf(votes, label);
            }
            votes[label - 1]++;
        }

        // Single-pass reservoir selection: each of the classes sharing the maximum vote
        // count ends up chosen with equal probability.
        int best = 0;
        int tied = 1;
        for (int idx = 1; idx < votes.length; idx++) {
            if (votes[idx] > votes[best]) {
                best = idx;
                tied = 1;
            } else if (votes[idx] == votes[best]) {
                tied++;
                if (Math.random() * tied < 1.0d) {
                    best = idx;
                }
            }
        }
        return best + 1;
    }

    /**
     * Returns the predicted label for {@code datapoint}; delegates to
     * {@link #predict(Datapoint)}.
     *
     * @param datapoint the point to classify
     * @return the predicted label
     */
    @Override // algorithms.classification.Classifier
    public double predictLabel(Datapoint datapoint) {
        return predict(datapoint);
    }

    /**
     * Changes the number of neighbours used by subsequent predictions.
     *
     * @param i the new {@code k}
     */
    public void setK(int i) {
        this.k = i;
    }

    /**
     * Stores {@code dataset} as the reference data for neighbour look-ups and calls
     * {@code setVisible(true)}. The dataset is kept by reference, not copied.
     *
     * @param dataset the labelled training data
     */
    @Override // algorithms.Algorithm
    public void train(Dataset dataset) {
        this.trainData = dataset;
        setVisible(true);
    }

    /**
     * Relabels every point of {@code dataset} in place with its predicted label and sets
     * the dataset's style to 2.
     *
     * <p>All predictions are computed before any label is written, so the result does not
     * depend on iteration order when {@code dataset} shares points with the training data.
     *
     * @param dataset the points to classify; modified in place
     * @return the same {@code dataset} instance
     */
    @Override // algorithms.Algorithm
    public Dataset getOutput(Dataset dataset) {
        // NOTE(review): the meaning of style code 2 is defined by Dataset; not visible here.
        dataset.setStyle(2);

        List<Datapoint> points = new ArrayList<Datapoint>();
        List<Double> predictions = new ArrayList<Double>();
        Iterator<Datapoint> it = dataset.iterator();
        while (it.hasNext()) {
            Datapoint point = it.next();
            points.add(point);
            predictions.add(predict(point));
        }

        for (int idx = 0; idx < points.size(); idx++) {
            points.get(idx).setLabel(predictions.get(idx));
        }
        return dataset;
    }
}
