package gui.options;

import algorithms.regression.MaxLikelihood;
import data.Dataset;
import gui.Plotter;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;

/* loaded from: input_file:gui/options/MaxLikePanel.class */
/**
 * Option panel that configures, trains and cross-validates a
 * {@link MaxLikelihood} regression model and shows it on a {@link Plotter}.
 *
 * <p>The panel offers one check box per basis term, a line-colour selector,
 * a "train" button, an error-bar toggle, two likelihood read-outs and a
 * "cross validate" button.
 */
public class MaxLikePanel extends OptionPanel implements ItemListener {
    /** Upper bound on the number of folds used by cross validation. */
    private static final int MAX_FOLDS = 10;

    /**
     * Check-box captions, one per model term. The array index of a caption is
     * the term index handed to {@code MaxLikelihood.setTerm}.
     */
    private static final String[] TERM_LABELS = {
        "Constant term",
        "x",
        "<html>x<sup>2</sup></html>",
        "<html>x<sup>3</sup></html>",
        "<html>x<sup>4</sup></html>",
        "<html>x<sup>5</sup></html>",
        "<html>x<sup>6</sup></html>",
        "<html>x<sup>7</sup></html>",
        "<html>x<sup>8</sup></html>",
        "sin x",
        "<html>e<sup>x</sup></html>"
    };

    /** Colour names offered in the combo box, in display order. */
    private static final String[] COLOUR_NAMES = {"Black", "Blue", "Red", "Green"};

    /** Colours matching {@link #COLOUR_NAMES} position by position. */
    private static final Color[] COLOURS = {Color.black, Color.blue, Color.red, Color.green};

    private JPanel boxPanel;
    private JComboBox colourBox;
    private JCheckBox[] boxes;
    private JCheckBox errors;
    private JButton train;
    private JButton cv;
    private JLabel trainLike;
    private JLabel testLike;
    private Plotter plot;
    /** Set once a call to {@code train} has completed without throwing. */
    private boolean trained = false;

    /**
     * Creates the panel and a fresh {@link MaxLikelihood} model for it.
     *
     * @param plotter plot that supplies the data sets and displays the model
     */
    public MaxLikePanel(Plotter plotter) {
        this.plot = plotter;
        this.alg = new MaxLikelihood();
    }

    /**
     * Builds the child components: term check boxes, colour selector, train
     * button, error-bar toggle, likelihood labels and cross-validate button.
     */
    @Override // gui.options.OptionPanel
    protected void layoutPanel() {
        this.boxPanel = new JPanel();
        this.boxPanel.setAlignmentX(Component.LEFT_ALIGNMENT);
        this.boxPanel.setLayout(new BoxLayout(this.boxPanel, BoxLayout.PAGE_AXIS));
        this.boxPanel.add(new JLabel("Select terms to include:"));
        this.boxes = new JCheckBox[TERM_LABELS.length];
        for (int term = 0; term < TERM_LABELS.length; term++) {
            this.boxes[term] = new JCheckBox(TERM_LABELS[term]);
            this.boxPanel.add(this.boxes[term]);
        }
        add(this.boxPanel);

        JPanel colourRow = new JPanel();
        colourRow.setLayout(new BoxLayout(colourRow, BoxLayout.LINE_AXIS));
        colourRow.setAlignmentX(Component.LEFT_ALIGNMENT);
        this.colourBox = new JComboBox();
        for (String colourName : COLOUR_NAMES) {
            this.colourBox.addItem(colourName);
        }
        this.colourBox.addItemListener(this);
        this.colourBox.setMaximumSize(new Dimension(60, 20));
        colourRow.add(new JLabel("Line colour: "));
        colourRow.add(this.colourBox);
        add(colourRow);

        this.train = new JButton("Train model");
        this.train.addActionListener(this);
        add(this.train);

        this.errors = new JCheckBox("Error bars on");
        this.errors.setSelected(true);
        this.errors.addActionListener(this);
        add(this.errors);

        add(Box.createVerticalGlue());

        this.trainLike = new JLabel("Training data likelihood: 0.0");
        add(this.trainLike);
        this.testLike = new JLabel("Test data likelihood: 0.0");
        add(this.testLike);

        this.cv = new JButton("Cross validate");
        this.cv.addActionListener(this);
        add(this.cv);
    }

    /**
     * Applies the newly selected line colour.
     *
     * <p>A combo box reports a change as a DESELECTED event for the old item
     * followed by a SELECTED event for the new one; only the latter is acted
     * on so the colour is set once per change.
     *
     * @param itemEvent event fired by the colour combo box
     */
    @Override
    public void itemStateChanged(ItemEvent itemEvent) {
        if (itemEvent.getStateChange() == ItemEvent.SELECTED) {
            changeColour((String) itemEvent.getItem());
        }
    }

    /**
     * Sets the colour the plot uses for this panel's model.
     *
     * @param str one of "Black", "Blue", "Red" or "Green"; for any other
     *            value (including {@code null}) a {@code null} colour is
     *            passed on to the plotter
     */
    public void changeColour(String str) {
        Color chosen = null;
        for (int i = 0; i < COLOUR_NAMES.length; i++) {
            if (COLOUR_NAMES[i].equals(str)) {
                chosen = COLOURS[i];
                break;
            }
        }
        this.plot.setAlgorithmColor(this.alg, chosen);
    }

    /**
     * Dispatches clicks on the train button, the cross-validate button and
     * the error-bar check box.
     *
     * @param actionEvent event fired by one of the panel's controls
     */
    @Override // gui.options.OptionPanel
    public void actionPerformed(ActionEvent actionEvent) {
        Object source = actionEvent.getSource();
        if (source == this.train) {
            trainModel();
        } else if (source == this.cv) {
            crossValidate();
        } else if (source == this.errors) {
            model().setErrors(this.errors.isSelected());
            // Nothing is drawn for an untrained model, so only repaint afterwards.
            if (this.trained) {
                this.plot.repaint();
            }
        }
    }

    /**
     * Refreshes the likelihood labels once a model has been trained, then
     * delegates to the superclass. The test label is only rewritten while the
     * plot holds at least one test point.
     *
     * <p>NOTE(review): if OptionPanel is a Swing panel this hook is also
     * invoked during superclass construction, before the labels exist; the
     * {@code trained} guard keeps that call from touching them - confirm.
     */
    @Override
    public void updateUI() {
        if (this.trained) {
            double trainingLikelihood = model().getTrainingLikelihood();
            this.trainLike.setText(
                    String.format("Training data likelihood: %1.5f", Double.valueOf(trainingLikelihood)));
            Dataset testData = this.plot.getTestData();
            if (testData.size() > 0) {
                double testLikelihood = model().getLikelihood(testData);
                this.testLike.setText(
                        String.format("Test data likelihood: %1.5f", Double.valueOf(testLikelihood)));
            }
        }
        super.updateUI();
    }

    /** Returns the panel's algorithm viewed as the concrete model type. */
    private MaxLikelihood model() {
        return (MaxLikelihood) this.alg;
    }

    /**
     * Copies the state of every term check box into the model.
     *
     * @return the number of terms currently selected
     */
    private int applySelectedTerms() {
        int selectedCount = 0;
        for (int term = 0; term < this.boxes.length; term++) {
            boolean selected = this.boxes[term].isSelected();
            if (selected) {
                selectedCount++;
            }
            model().setTerm(term, selected);
        }
        return selectedCount;
    }

    /**
     * Trains the model on the plot's training data using the selected terms,
     * then registers it with the plot and applies the chosen line colour.
     *
     * <p>Warning dialogs are shown only when the plot has warnings enabled.
     * The model is registered with the plot even if training fails.
     */
    private void trainModel() {
        Dataset trainingData = this.plot.getTrainingData();
        int selectedTerms = applySelectedTerms();
        if (selectedTerms > trainingData.size() && this.plot.warnings()) {
            JOptionPane.showMessageDialog((Component) null, "There are more terms than data points. The plotted curve will be inaccurate. \nAdd more points to correct this.", "Warning", JOptionPane.WARNING_MESSAGE);
        }
        try {
            this.alg.train(trainingData);
            this.trained = true;
            this.alg.setVisible(true);
        } catch (Exception e) {
            // Deliberately best-effort: a failed fit is reported (if warnings
            // are enabled) rather than propagated to the event thread.
            if (this.plot.warnings()) {
                JOptionPane.showMessageDialog((Component) null, "The model cannot be drawn - the matrix is singular to the machine's precision.", "Warning", JOptionPane.ERROR_MESSAGE);
            }
        }
        this.plot.addAlgorithm(this.alg);
        changeColour((String) this.colourBox.getSelectedItem());
    }

    /**
     * Runs k-fold cross validation on the training data, with k equal to the
     * number of points capped at {@link #MAX_FOLDS}, and reports the summed
     * held-out likelihood divided by the number of training points.
     *
     * <p>Folds whose likelihood is NaN contribute nothing to the sum. The
     * model is retrained on the complete training set afterwards, whether or
     * not cross validation succeeded, because the fold loop leaves it fitted
     * to a subset of the data.
     */
    private void crossValidate() {
        Dataset trainingData = this.plot.getTrainingData();
        int foldCount = Math.min(trainingData.size(), MAX_FOLDS);
        if (foldCount == 0) {
            JOptionPane.showMessageDialog((Component) null, "No data has been added.", "Error", JOptionPane.ERROR_MESSAGE);
            return;
        }
        // The folds must be fitted with the terms currently ticked in the UI.
        applySelectedTerms();
        try {
            Dataset[] folds = trainingData.fold(foldCount);
            double totalLikelihood = 0.0d;
            for (int heldOut = 0; heldOut < folds.length; heldOut++) {
                Dataset fitting = new Dataset();
                for (int other = 0; other < folds.length; other++) {
                    if (other != heldOut) {
                        fitting.addAll(folds[other]);
                    }
                }
                this.alg.train(fitting);
                double likelihood = model().getLikelihood(folds[heldOut]);
                // NaN never compares equal to anything, itself included, so it
                // has to be detected with Double.isNaN.
                if (!Double.isNaN(likelihood)) {
                    totalLikelihood += likelihood;
                }
            }
            JOptionPane.showMessageDialog((Component) null, foldCount + "-fold Cross Validation: \n Average likelihood per point = " + (totalLikelihood / trainingData.size()), "CV", JOptionPane.INFORMATION_MESSAGE);
        } catch (RuntimeException e) {
            // Training on a reduced fold can fail (e.g. too few points left);
            // report it instead of letting it escape onto the event thread.
            JOptionPane.showMessageDialog((Component) null, "Cross validation could not be completed - the model could not be trained on one of the folds.", "Error", JOptionPane.ERROR_MESSAGE);
        }
        trainModel();
    }
}
