package gui.options;

import algorithms.classification.Classifier;
import algorithms.classification.SVM;
import data.Dataset;
import functions.kernels.GaussianKernel;
import functions.kernels.Kernel;
import functions.kernels.LinearKernel;
import functions.kernels.PolynomialKernel;
import gui.Plotter;
import java.awt.CardLayout;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import javax.swing.Box;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JTextField;

/* loaded from: input_file:gui/options/SVMPanel.class */
/**
 * Option panel for configuring, training and cross-validating an {@link SVM} on the data
 * held by a {@link Plotter}.
 *
 * <p>The user picks a kernel (linear, Gaussian or polynomial) and its parameter, the constant
 * C and a colouring mode. "Train model" builds a new SVM from those settings, trains it on the
 * plot's training data and registers it with the plot. "Cross validate" runs k-fold cross
 * validation (k = min(training points, 10)) with the same settings, then trains on the full
 * training set.
 */
public class SVMPanel extends OptionPanel implements ItemListener {
    /** Upper bound on the number of cross-validation folds. */
    private static final int MAX_FOLDS = 10;

    // Kernel labels; they double as the CardLayout card names in kernelOpt.
    private static final String KERNEL_LINEAR = "Linear";
    private static final String KERNEL_GAUSSIAN = "Gaussian";
    private static final String KERNEL_POLYNOMIAL = "Polynomial";

    // Colouring-mode labels and the values this panel passes to SVM.setOutputMode for them.
    private static final String MODE_THRESHOLDED = "Thresholded";
    private static final String MODE_RAW = "Raw";
    private static final int OUTPUT_MODE_THRESHOLDED = 1;
    private static final int OUTPUT_MODE_RAW = 0;

    // NOTE: the widget fields below deliberately have no initialisers. If layoutPanel() is ever
    // invoked from a superclass constructor, an initialiser would run afterwards and wipe them.
    private JPanel pan;
    private JPanel kernelOpt;
    private Plotter plot;
    private JButton train;
    private JButton cv;
    private JTextField cBox;
    private JTextField gaussBeta;
    private JTextField polyBeta;
    private JComboBox kernel;
    private JComboBox outputType;
    private JLabel trainErrors;
    private JLabel testErrors;
    /** True once trainModel() has produced a model whose errors can be reported. */
    private boolean trained = false;

    /**
     * Creates the panel for the given plot, starting from a default {@link SVM}.
     *
     * @param plotter plot that supplies training/test data and displays the trained model
     */
    public SVMPanel(Plotter plotter) {
        this.plot = plotter;
        this.alg = new SVM();
    }

    /**
     * Builds the panel's widgets and adds them in this order:
     * <ul>
     *   <li>the kernel selector;</li>
     *   <li>a CardLayout holding each kernel's parameter field;</li>
     *   <li>the colouring-mode selector;</li>
     *   <li>the C field;</li>
     *   <li>the train button;</li>
     *   <li>the misclassification labels;</li>
     *   <li>the cross-validate button.</li>
     * </ul>
     */
    @Override
    protected void layoutPanel() {
        this.kernel = new JComboBox();
        this.kernel.addItem(KERNEL_LINEAR);
        this.kernel.addItem(KERNEL_GAUSSIAN);
        this.kernel.addItem(KERNEL_POLYNOMIAL);
        // Listener is attached after the items so populating the box fires no events.
        this.kernel.addItemListener(this);
        JPanel kernelRow = new JPanel();
        kernelRow.add(new JLabel("Select kernel: "));
        kernelRow.add(this.kernel);
        add(kernelRow);

        JPanel gaussianRow = new JPanel();
        gaussianRow.add(new JLabel("Enter beta: "));
        this.gaussBeta = new JTextField("1.0", 3);
        this.gaussBeta.addActionListener(this);
        gaussianRow.add(this.gaussBeta);

        JPanel polynomialRow = new JPanel();
        polynomialRow.add(new JLabel("Enter exponent: "));
        this.polyBeta = new JTextField("2", 3);
        this.polyBeta.addActionListener(this);
        polynomialRow.add(this.polyBeta);

        // One card per kernel, keyed by the same label shown in the kernel combo box, so
        // itemStateChanged can switch cards using the selected item directly.
        this.kernelOpt = new JPanel();
        this.kernelOpt.setMaximumSize(new Dimension(200, 100));
        this.kernelOpt.setLayout(new CardLayout());
        this.kernelOpt.add(new JPanel(), KERNEL_LINEAR);
        this.kernelOpt.add(gaussianRow, KERNEL_GAUSSIAN);
        this.kernelOpt.add(polynomialRow, KERNEL_POLYNOMIAL);

        JPanel modeRow = new JPanel();
        modeRow.add(new JLabel("Select colouring mode: "));
        this.outputType = new JComboBox();
        this.outputType.addItem(MODE_THRESHOLDED);
        this.outputType.addItem(MODE_RAW);
        this.outputType.addItemListener(this);
        modeRow.add(this.outputType);

        this.pan = new JPanel();
        this.pan.add(new JLabel("Enter C: "));
        this.cBox = new JTextField("1.0", 3);
        this.cBox.addActionListener(this);
        this.pan.add(this.cBox);

        add(this.kernelOpt);
        add(modeRow);
        add(this.pan);

        this.train = new JButton("Train model");
        this.train.setAlignmentX(Component.CENTER_ALIGNMENT);
        this.train.addActionListener(this);
        add(this.train);
        add(Box.createVerticalGlue());

        this.trainErrors = new JLabel("Training misclassifications: 0/0");
        this.trainErrors.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(this.trainErrors);
        this.testErrors = new JLabel("Testing misclassifications: 0/0");
        this.testErrors.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(this.testErrors);

        this.cv = new JButton("Cross validate");
        this.cv.addActionListener(this);
        this.cv.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(this.cv);
    }

    /**
     * Refreshes the misclassification labels (once a model has been trained), then delegates
     * to {@code super.updateUI()}.
     *
     * <p>Swing component constructors (e.g. JPanel's) call updateUI() before this class's fields
     * are assigned, so the label refresh is guarded against that.
     */
    @Override
    public void updateUI() {
        refreshErrorLabels();
        super.updateUI();
    }

    /**
     * Shows the current model's misclassification counts on the training and test data.
     * Does nothing until a model has been trained or before the labels exist.
     */
    private void refreshErrorLabels() {
        if (!this.trained || this.trainErrors == null || this.testErrors == null) {
            return;
        }
        Classifier classifier = (Classifier) this.alg;
        Dataset trainingData = this.plot.getTrainingData();
        this.trainErrors.setText("Training misclassifications: " + classifier.classificationErrors(trainingData) + "/" + trainingData.size());
        Dataset testData = this.plot.getTestData();
        this.testErrors.setText("Testing misclassifications: " + classifier.classificationErrors(testData) + "/" + testData.size());
    }

    /**
     * Maps a colouring-mode label to the value passed to SVM.setOutputMode.
     *
     * @param label a colouring-mode combo box item
     * @return OUTPUT_MODE_RAW for "Raw", otherwise OUTPUT_MODE_THRESHOLDED
     */
    private static int outputModeFor(String label) {
        return MODE_RAW.equals(label) ? OUTPUT_MODE_RAW : OUTPUT_MODE_THRESHOLDED;
    }

    /**
     * Creates the kernel currently selected in the UI, using its parameter field.
     *
     * @return a Gaussian or polynomial kernel when selected, otherwise a linear kernel
     * @throws NumberFormatException if the selected kernel's parameter field is not a number
     */
    private Kernel createKernel() {
        String name = (String) this.kernel.getSelectedItem();
        if (KERNEL_GAUSSIAN.equals(name)) {
            return new GaussianKernel(Double.parseDouble(this.gaussBeta.getText()));
        }
        if (KERNEL_POLYNOMIAL.equals(name)) {
            return new PolynomialKernel(Double.parseDouble(this.polyBeta.getText()));
        }
        return new LinearKernel();
    }

    /**
     * Builds an untrained SVM from the current kernel, C and colouring-mode settings.
     *
     * @return the configured model, or {@code null} after showing an error dialog if one of
     *     the numeric fields cannot be parsed
     */
    private SVM createConfiguredSVM() {
        Kernel kernelFunction;
        double c;
        try {
            kernelFunction = createKernel();
            c = Double.parseDouble(this.cBox.getText());
        } catch (NumberFormatException e) {
            JOptionPane.showMessageDialog(null, "Please enter a valid number.\n" + e.getMessage(), "Error", JOptionPane.ERROR_MESSAGE);
            return null;
        }
        SVM svm = new SVM(kernelFunction);
        svm.setC(c);
        svm.setOutputMode(outputModeFor((String) this.outputType.getSelectedItem()));
        return svm;
    }

    /**
     * Replaces the plotted model with a new SVM trained on the plot's training data.
     * If the settings are invalid, the current model is left on the plot untouched.
     */
    private void trainModel() {
        // Validate and build first so a typo cannot leave the plot without any model.
        SVM model = createConfiguredSVM();
        if (model == null) {
            return;
        }
        Dataset trainingData = this.plot.getTrainingData();
        this.plot.removeAlgorithm(this.alg);
        this.alg = model;
        model.train(trainingData);
        model.setVisible(true);
        this.plot.addAlgorithm(model);
        this.trained = true;
        refreshErrorLabels();
    }

    /**
     * Runs k-fold cross validation (k = min(training points, MAX_FOLDS)) using the current UI
     * settings. It reports the mean number of held-out misclassifications per training point,
     * then trains and plots a model on the full training set.
     */
    private void crossValidate() {
        Dataset trainingData = this.plot.getTrainingData();
        int size = Math.min(trainingData.size(), MAX_FOLDS);
        if (size == 0) {
            JOptionPane.showMessageDialog(null, "No data has been added.", "Error", JOptionPane.ERROR_MESSAGE);
            return;
        }
        // Dataset.fold(k) is assumed to split the data into k disjoint folds -- confirm in Dataset.
        Dataset[] folds = trainingData.fold(size);
        double totalErrors = 0.0d;
        for (int held = 0; held < folds.length; held++) {
            // A fresh model per fold, built from the UI settings, so the plotted model is never
            // retrained on a subset and stale settings are never used.
            SVM model = createConfiguredSVM();
            if (model == null) {
                return; // error dialog already shown
            }
            Dataset foldTraining = new Dataset();
            for (int other = 0; other < folds.length; other++) {
                if (other != held) {
                    foldTraining.addAll(folds[other]);
                }
            }
            model.train(foldTraining);
            totalErrors += ((Classifier) model).classificationErrors(folds[held]);
        }
        trainModel();
        JOptionPane.showMessageDialog(null, size + "-fold Cross Validation: \n Average misclassifications per point = " + (totalErrors / trainingData.size()), "CV", JOptionPane.INFORMATION_MESSAGE);
    }

    /** Dispatches the "Train model" and "Cross validate" buttons; other sources are ignored. */
    @Override
    public void actionPerformed(ActionEvent actionEvent) {
        Object source = actionEvent.getSource();
        if (source == this.train) {
            trainModel();
        } else if (source == this.cv) {
            crossValidate();
        }
    }

    /**
     * Handles a combo box change.
     * <ul>
     *   <li>Kernel selector: switches the parameter card.</li>
     *   <li>Colouring-mode selector: updates the current model's output mode and repaints the
     *       plot.</li>
     * </ul>
     */
    @Override
    public void itemStateChanged(ItemEvent itemEvent) {
        // JComboBox fires DESELECTED for the old item and SELECTED for the new one; react once.
        if (itemEvent.getStateChange() != ItemEvent.SELECTED) {
            return;
        }
        String label = (String) itemEvent.getItem();
        Object source = itemEvent.getSource();
        if (source == this.kernel) {
            ((CardLayout) this.kernelOpt.getLayout()).show(this.kernelOpt, label);
        } else if (source == this.outputType) {
            ((SVM) this.alg).setOutputMode(outputModeFor(label));
            this.plot.repaint();
        }
    }
}
