package weka.clusterers;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.kstar.KStarConstants;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;

/* loaded from: input_file:weka/clusterers/EM.class */
/**
 * Clusterer that fits a mixture model with the expectation-maximization (EM) algorithm.
 *
 * <p>Each cluster models every nominal attribute with a {@link DiscreteEstimator} and every
 * other attribute with a normal distribution; attributes are treated as independent within a
 * cluster (per-attribute densities are multiplied together). The number of clusters is either
 * given explicitly or, when set to -1, chosen by cross validation on the log likelihood.
 *
 * <p>Instances of this class keep mutable model state and are not synchronized.
 */
public class EM extends DistributionClusterer implements OptionHandler {
    /** Slot of {@code m_modelNormal[cluster][attribute]} that holds the (weighted) mean. */
    private static final int MEAN = 0;

    /** Slot of {@code m_modelNormal[cluster][attribute]} that holds the standard deviation. */
    private static final int STD_DEV = 1;

    /** Slot of {@code m_modelNormal[cluster][attribute]} that holds the sum of weights. */
    private static final int WEIGHT_SUM = 2;

    /** Largest per-instance weight below which the weights are rescaled in the E step. */
    private static final double UNDERFLOW_THRESHOLD = 1.0E-75d;

    /** Factor applied to an instance's weights when they fall below the threshold. */
    private static final double UNDERFLOW_RESCALE = 1.0E75d;

    /** Minimum log-likelihood improvement required to keep iterating. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6d;

    /** Upper bound on the number of folds used when cross validating the cluster count. */
    private static final int MAX_CV_FOLDS = 10;

    /** Normalising constant of the normal density: sqrt(2 * pi). */
    private static final double m_normConst = Math.sqrt(6.283185307179586d);

    /** Estimators for nominal attributes, indexed [cluster][attribute]. */
    private Estimator[][] m_model;

    /** Normal-distribution parameters, indexed [cluster][attribute][MEAN|STD_DEV|WEIGHT_SUM]. */
    private double[][][] m_modelNormal;

    /** Cluster membership weights, indexed [instance][cluster]. */
    private double[][] m_weights;

    /** Prior probability of each cluster. */
    private double[] m_priors;

    /** Value returned by the last full run of {@link #iterate}. */
    private double m_loglikely;

    /** Number of clusters in use (-1 until known). */
    private int m_num_clusters;

    /** Number of clusters requested by the user (-1 means "select by cross validation"). */
    private int m_initialNumClusters;

    /** Number of attributes in the training data. */
    private int m_num_attribs;

    /** Number of instances in the training data. */
    private int m_num_instances;

    /** Maximum number of EM iterations. */
    private int m_max_iterations;

    /** Random number generator used to initialise the membership weights. */
    private Random m_rr;

    /** Seed for the random number generators. */
    private int m_rseed;

    /** Whether progress information is printed to standard output. */
    private boolean m_verbose;

    /** Lower bound applied to every estimated standard deviation. */
    private double m_minStdDev = 1.0E-6d;

    /** Training data while building; reduced to a header-only copy afterwards. */
    private Instances m_theInstances = null;

    /**
     * Creates an EM clusterer with all options at their defaults.
     */
    public EM() {
        resetOptions();
    }

    /**
     * Returns a short description of this clusterer.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Cluster data using expectation maximization";
    }

    /**
     * Lists the command-line options understood by {@link #setOptions(String[])}.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override // weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(5);
        options.addElement(new Option("\tnumber of clusters. If omitted or\n\t-1 specified, then cross validation is used to\n\tselect the number of clusters.", "N", 1, "-N <num>"));
        options.addElement(new Option("\tmax iterations.\n(default 100)", "I", 1, "-I <num>"));
        // The documented default must match the value assigned in resetOptions().
        options.addElement(new Option("\trandom number seed.\n(default 100)", "S", 1, "-S <num>"));
        options.addElement(new Option("\tverbose.", "V", 0, "-V"));
        options.addElement(new Option("\tminimum allowable standard deviation for normal density computation \n\t(default 1e-6)", "M", 1, "-M <num>"));
        return options.elements();
    }

    /**
     * Resets every option to its default and then applies the options found in the array:
     * {@code -V} (verbose), {@code -I} (max iterations), {@code -N} (number of clusters),
     * {@code -S} (seed) and {@code -M} (minimum standard deviation).
     *
     * @param strArr the option strings
     * @throws Exception if an option value is malformed or rejected by its setter
     */
    @Override // weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        resetOptions();
        setDebug(Utils.getFlag('V', strArr));

        String maxIterations = Utils.getOption('I', strArr);
        if (maxIterations.length() != 0) {
            setMaxIterations(Integer.parseInt(maxIterations));
        }

        String numClusters = Utils.getOption('N', strArr);
        if (numClusters.length() != 0) {
            setNumClusters(Integer.parseInt(numClusters));
        }

        String seed = Utils.getOption('S', strArr);
        if (seed.length() != 0) {
            setSeed(Integer.parseInt(seed));
        }

        String minStdDev = Utils.getOption('M', strArr);
        if (minStdDev.length() != 0) {
            setMinStdDev(Double.parseDouble(minStdDev));
        }
    }

    /**
     * Returns the tip text for the minimum standard deviation property.
     *
     * @return the tip text
     */
    public String minStdDevTipText() {
        return "set minimum allowable standard deviation";
    }

    /**
     * Sets the lower bound applied to estimated standard deviations.
     *
     * @param d the minimum allowable standard deviation
     */
    public void setMinStdDev(double d) {
        this.m_minStdDev = d;
    }

    /**
     * Returns the lower bound applied to estimated standard deviations.
     *
     * @return the minimum allowable standard deviation
     */
    public double getMinStdDev() {
        return this.m_minStdDev;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "random number seed";
    }

    /**
     * Sets the random number seed.
     *
     * @param i the seed
     */
    public void setSeed(int i) {
        this.m_rseed = i;
    }

    /**
     * Returns the random number seed.
     *
     * @return the seed
     */
    public int getSeed() {
        return this.m_rseed;
    }

    /**
     * Returns the tip text for the number-of-clusters property.
     *
     * @return the tip text
     */
    public String numClustersTipText() {
        return "set number of clusters. -1 to select number of clusters automatically by cross validation.";
    }

    /**
     * Sets the number of clusters. Any negative value is stored as -1, which requests selection
     * by cross validation.
     *
     * @param i the number of clusters, or a negative value for automatic selection
     * @throws Exception if {@code i} is zero
     */
    public void setNumClusters(int i) throws Exception {
        if (i == 0) {
            throw new Exception("Number of clusters must be > 0. (or -1 to select by cross validation).");
        }
        int requested = (i < 0) ? -1 : i;
        this.m_num_clusters = requested;
        this.m_initialNumClusters = requested;
    }

    /**
     * Returns the number of clusters requested by the user.
     *
     * @return the requested number of clusters, or -1 for selection by cross validation
     */
    public int getNumClusters() {
        return this.m_initialNumClusters;
    }

    /**
     * Returns the tip text for the maximum-iterations property.
     *
     * @return the tip text
     */
    public String maxIterationsTipText() {
        return "maximum number of iterations";
    }

    /**
     * Sets the maximum number of EM iterations.
     *
     * @param i the maximum number of iterations
     * @throws Exception if {@code i} is smaller than 1
     */
    public void setMaxIterations(int i) throws Exception {
        if (i < 1) {
            throw new Exception("Maximum number of iterations must be > 0!");
        }
        this.m_max_iterations = i;
    }

    /**
     * Returns the maximum number of EM iterations.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIterations() {
        return this.m_max_iterations;
    }

    /**
     * Switches verbose output on or off.
     *
     * @param z {@code true} to print progress information
     */
    public void setDebug(boolean z) {
        this.m_verbose = z;
    }

    /**
     * Reports whether verbose output is enabled.
     *
     * @return {@code true} if progress information is printed
     */
    public boolean getDebug() {
        return this.m_verbose;
    }

    /**
     * Returns the current option settings. The array always has nine entries; unused trailing
     * entries are empty strings.
     *
     * @return the option strings
     */
    @Override // weka.core.OptionHandler
    public String[] getOptions() {
        String[] options = new String[9];
        int next = 0;
        if (this.m_verbose) {
            options[next++] = "-V";
        }
        options[next++] = "-I";
        options[next++] = String.valueOf(this.m_max_iterations);
        options[next++] = "-N";
        options[next++] = String.valueOf(getNumClusters());
        options[next++] = "-S";
        options[next++] = String.valueOf(this.m_rseed);
        options[next++] = "-M";
        options[next++] = String.valueOf(getMinStdDev());
        while (next < options.length) {
            options[next++] = "";
        }
        return options;
    }

    /**
     * Allocates the model arrays and assigns every instance random, normalised cluster
     * membership weights, then derives the cluster priors from those weights.
     *
     * @param instances the data the weights are created for
     * @param i the number of clusters
     * @throws Exception if normalising the weights or priors fails
     */
    private void EM_Init(Instances instances, int i) throws Exception {
        int numInstances = instances.numInstances();
        this.m_weights = new double[numInstances][i];
        this.m_model = new Estimator[i][this.m_num_attribs];
        this.m_modelNormal = new double[i][this.m_num_attribs][3];
        this.m_priors = new double[i];
        for (int inst = 0; inst < numInstances; inst++) {
            double[] membership = this.m_weights[inst];
            for (int cluster = 0; cluster < i; cluster++) {
                membership[cluster] = this.m_rr.nextDouble();
            }
            Utils.normalize(membership);
        }
        estimate_priors(instances, i);
    }

    /**
     * Recomputes the cluster priors as the normalised column sums of the membership weights.
     *
     * @param instances the data whose weights are summed
     * @param i the number of clusters
     * @throws Exception if normalising the priors fails
     */
    private void estimate_priors(Instances instances, int i) throws Exception {
        for (int cluster = 0; cluster < i; cluster++) {
            this.m_priors[cluster] = 0.0d;
        }
        for (int inst = 0; inst < instances.numInstances(); inst++) {
            double[] membership = this.m_weights[inst];
            for (int cluster = 0; cluster < i; cluster++) {
                this.m_priors[cluster] += membership[cluster];
            }
        }
        Utils.normalize(this.m_priors);
    }

    /**
     * Evaluates the normal density.
     *
     * @param d the value at which the density is evaluated
     * @param d2 the mean
     * @param d3 the standard deviation
     * @return the density at {@code d}
     */
    private double normalDens(double d, double d2, double d3) {
        double diff = d - d2;
        return (1 / (m_normConst * d3)) * Math.exp(-((diff * diff) / ((2 * d3) * d3)));
    }

    /**
     * Resets the per-cluster models: a fresh {@link DiscreteEstimator} for every nominal
     * attribute and zeroed statistics for every other attribute.
     *
     * @param i the number of clusters
     */
    private void new_estimators(int i) {
        for (int cluster = 0; cluster < i; cluster++) {
            for (int att = 0; att < this.m_num_attribs; att++) {
                if (this.m_theInstances.attribute(att).isNominal()) {
                    this.m_model[cluster][att] = new DiscreteEstimator(this.m_theInstances.attribute(att).numValues(), true);
                } else {
                    double[] stats = this.m_modelNormal[cluster][att];
                    stats[MEAN] = 0.0d;
                    stats[STD_DEV] = 0.0d;
                    stats[WEIGHT_SUM] = 0.0d;
                }
            }
        }
    }

    /**
     * Maximisation step: re-estimates every cluster's per-attribute model from the current
     * membership weights. Missing values are skipped.
     *
     * @param instances the training data
     * @param i the number of clusters
     * @throws Exception if the data cannot be accessed
     */
    private void M(Instances instances, int i) throws Exception {
        new_estimators(i);

        // Pass 1: accumulate weighted sufficient statistics. For non-nominal attributes the
        // MEAN slot temporarily holds sum(w*x) and the STD_DEV slot holds sum(w*x*x).
        for (int cluster = 0; cluster < i; cluster++) {
            for (int att = 0; att < this.m_num_attribs; att++) {
                boolean nominal = instances.attribute(att).isNominal();
                for (int inst = 0; inst < instances.numInstances(); inst++) {
                    Instance current = instances.instance(inst);
                    if (current.isMissing(att)) {
                        continue;
                    }
                    double value = current.value(att);
                    double weight = this.m_weights[inst][cluster];
                    if (nominal) {
                        this.m_model[cluster][att].addValue(value, weight);
                    } else {
                        double[] stats = this.m_modelNormal[cluster][att];
                        stats[MEAN] += value * weight;
                        stats[WEIGHT_SUM] += weight;
                        stats[STD_DEV] += value * value * weight;
                    }
                }
            }
        }

        // Pass 2: turn the accumulated sums into a mean and a standard deviation.
        for (int att = 0; att < this.m_num_attribs; att++) {
            if (instances.attribute(att).isNominal()) {
                continue;
            }
            for (int cluster = 0; cluster < i; cluster++) {
                double[] stats = this.m_modelNormal[cluster][att];
                if (stats[WEIGHT_SUM] < 0) {
                    stats[STD_DEV] = 0.0d;
                } else {
                    stats[STD_DEV] = (stats[STD_DEV] - ((stats[MEAN] * stats[MEAN]) / stats[WEIGHT_SUM])) / stats[WEIGHT_SUM];
                    stats[STD_DEV] = Math.sqrt(stats[STD_DEV]);
                    // A zero weight sum or a slightly negative variance yields NaN; both that
                    // and too-small deviations are clamped to the configured minimum.
                    if (stats[STD_DEV] <= this.m_minStdDev || Double.isNaN(stats[STD_DEV])) {
                        stats[STD_DEV] = this.m_minStdDev;
                    }
                    if (stats[WEIGHT_SUM] > KStarConstants.FLOOR) {
                        stats[MEAN] /= stats[WEIGHT_SUM];
                    }
                }
            }
        }
    }

    /**
     * Expectation step: recomputes the membership weights of every instance from the current
     * model, then updates the cluster priors.
     *
     * <p>NOTE(review): when an instance's weights are rescaled to avoid underflow, the rescaled
     * sum is what enters the returned log likelihood; the scale factor is not subtracted again.
     *
     * @param instances the data to compute weights for
     * @param i the number of clusters
     * @return the sum of the logs of the instances' unnormalised weight totals (instances with
     *     a non-positive total are left out of the sum), divided by the number of instances
     * @throws Exception if a joint density overflows to infinity or an instance's weights
     *     cannot be normalised
     */
    private double E(Instances instances, int i) throws Exception {
        double logLikelihood = 0.0d;
        for (int inst = 0; inst < instances.numInstances(); inst++) {
            Instance current = instances.instance(inst);
            double[] membership = this.m_weights[inst];
            for (int cluster = 0; cluster < i; cluster++) {
                membership[cluster] = this.m_priors[cluster];
            }

            for (int att = 0; att < this.m_num_attribs; att++) {
                if (current.isMissing(att)) {
                    // Missing values contribute nothing to the joint density.
                    continue;
                }
                boolean nominal = instances.attribute(att).isNominal();
                double value = current.value(att);
                double largest = 0.0d;
                for (int cluster = 0; cluster < i; cluster++) {
                    if (nominal) {
                        membership[cluster] *= this.m_model[cluster][att].getProbability(value);
                    } else {
                        double[] stats = this.m_modelNormal[cluster][att];
                        membership[cluster] *= normalDens(value, stats[MEAN], stats[STD_DEV]);
                        if (Double.isInfinite(membership[cluster])) {
                            throw new Exception("Joint density has overflowed. Try increasing the minimum allowable standard deviation for normal density calculation.");
                        }
                    }
                    if (membership[cluster] > largest) {
                        largest = membership[cluster];
                    }
                }
                // Rescale all of this instance's weights when they are about to underflow.
                if (largest > 0 && largest < UNDERFLOW_THRESHOLD) {
                    for (int cluster = 0; cluster < i; cluster++) {
                        membership[cluster] *= UNDERFLOW_RESCALE;
                    }
                }
            }

            double total = 0.0d;
            for (int cluster = 0; cluster < i; cluster++) {
                total += membership[cluster];
            }
            if (total > 0) {
                logLikelihood += Math.log(total);
            }
            try {
                Utils.normalize(membership);
            } catch (Exception e) {
                // Keep the original failure as the cause so it is not lost.
                throw new Exception("An instance has zero cluster memberships. Try increasing the minimum allowable standard deviation for normal density calculation.", e);
            }
        }
        estimate_priors(instances, i);
        return logLikelihood / instances.numInstances();
    }

    /**
     * Restores the default option values: minimum standard deviation 1e-6, 100 iterations,
     * seed 100, number of clusters selected by cross validation, verbose output off.
     */
    protected void resetOptions() {
        this.m_minStdDev = 1.0E-6d;
        this.m_max_iterations = 100;
        this.m_rseed = 100;
        this.m_num_clusters = -1;
        this.m_initialNumClusters = -1;
        this.m_verbose = false;
    }

    /**
     * Describes the fitted model: the number of clusters and, for each cluster, its prior and
     * the distribution of every attribute.
     *
     * @return the model description, or a short notice if no model has been built yet
     */
    public String toString() {
        if (this.m_priors == null) {
            // Without this guard a preset cluster count would index a null model below.
            return "No clusterer built yet!";
        }
        StringBuilder text = new StringBuilder();
        text.append("\nEM\n==\n");
        if (this.m_initialNumClusters == -1) {
            text.append("\nNumber of clusters selected by cross validation: ").append(this.m_num_clusters).append("\n");
        } else {
            text.append("\nNumber of clusters: ").append(this.m_num_clusters).append("\n");
        }
        for (int cluster = 0; cluster < this.m_num_clusters; cluster++) {
            text.append("\nCluster: ").append(cluster)
                .append(" Prior probability: ").append(Utils.doubleToString(this.m_priors[cluster], 4))
                .append("\n\n");
            for (int att = 0; att < this.m_num_attribs; att++) {
                text.append("Attribute: ").append(this.m_theInstances.attribute(att).name()).append("\n");
                if (!this.m_theInstances.attribute(att).isNominal()) {
                    double[] stats = this.m_modelNormal[cluster][att];
                    text.append("Normal Distribution. Mean = ").append(Utils.doubleToString(stats[MEAN], 4))
                        .append(" StdDev = ").append(Utils.doubleToString(stats[STD_DEV], 4))
                        .append("\n");
                } else if (this.m_model[cluster][att] != null) {
                    text.append(this.m_model[cluster][att].toString());
                }
            }
        }
        return text.toString();
    }

    /**
     * Prints the current model and the membership weights of every instance to standard output.
     *
     * @param instances the data whose membership weights are printed
     */
    private void EM_Report(Instances instances) {
        System.out.println("======================================");
        for (int cluster = 0; cluster < this.m_num_clusters; cluster++) {
            for (int att = 0; att < this.m_num_attribs; att++) {
                System.out.println("Clust: " + cluster + " att: " + att + "\n");
                if (!this.m_theInstances.attribute(att).isNominal()) {
                    double[] stats = this.m_modelNormal[cluster][att];
                    System.out.println("Normal Distribution. Mean = " + Utils.doubleToString(stats[MEAN], 8, 4)
                        + " StandardDev = " + Utils.doubleToString(stats[STD_DEV], 8, 4)
                        + " WeightSum = " + Utils.doubleToString(stats[WEIGHT_SUM], 8, 4));
                } else if (this.m_model[cluster][att] != null) {
                    System.out.println(this.m_model[cluster][att].toString());
                }
            }
        }
        for (int inst = 0; inst < instances.numInstances(); inst++) {
            System.out.print("Inst " + Utils.doubleToString(inst, 5, 0)
                + " Class " + Utils.maxIndex(this.m_weights[inst]) + "\t");
            for (int cluster = 0; cluster < this.m_num_clusters; cluster++) {
                System.out.print(Utils.doubleToString(this.m_weights[inst][cluster], 7, 5) + "  ");
            }
            System.out.println();
        }
    }

    /**
     * Selects the number of clusters by cross validation: starting from one cluster, the count
     * is increased for as long as the mean log likelihood over the test folds improves.
     *
     * @return the last number of clusters that improved the mean log likelihood
     * @throws Exception if fitting or evaluating a model fails
     */
    private int CVClusters() throws Exception {
        double bestLogLikelihood = -Double.MAX_VALUE;
        int candidate = 1;
        int folds = Math.min(this.m_theInstances.numInstances(), MAX_CV_FOLDS);
        boolean improved = true;
        while (improved) {
            improved = false;
            // The same seed is used for every candidate so all of them see identical folds.
            Random shuffler = new Random(this.m_rseed);
            Instances shuffled = new Instances(this.m_theInstances);
            shuffled.randomize(shuffler);

            double total = 0.0d;
            for (int fold = 0; fold < folds; fold++) {
                Instances train = shuffled.trainCV(folds, fold);
                Instances test = shuffled.testCV(folds, fold);
                EM_Init(train, candidate);
                iterate(train, candidate, false);
                double foldLogLikelihood = E(test, candidate);
                if (this.m_verbose) {
                    System.out.println("# clust: " + candidate + " Fold: " + fold + " Loglikely: " + foldLogLikelihood);
                }
                total += foldLogLikelihood;
            }

            double mean = total / folds;
            if (this.m_verbose) {
                System.out.println("=================================================\n# clust: " + candidate
                    + " Mean Loglikely: " + mean
                    + "\n================================" + "=================");
            }
            if (mean > bestLogLikelihood) {
                bestLogLikelihood = mean;
                improved = true;
                candidate++;
            }
        }
        if (this.m_verbose) {
            System.out.println("Number of clusters: " + (candidate - 1));
        }
        return candidate - 1;
    }

    /**
     * Returns the number of clusters.
     *
     * @return the number of clusters
     * @throws Exception if the number of clusters is still undetermined (-1)
     */
    @Override // weka.clusterers.Clusterer
    public int numberOfClusters() throws Exception {
        if (this.m_num_clusters == -1) {
            throw new Exception("Haven't generated any clusters!");
        }
        return this.m_num_clusters;
    }

    /**
     * Fits the mixture model to the given data. Afterwards only an empty, header-only copy of
     * the data is retained.
     *
     * @param instances the training data
     * @throws Exception if the data contains string attributes or fitting fails
     */
    @Override // weka.clusterers.Clusterer
    public void buildClusterer(Instances instances) throws Exception {
        if (instances.checkForStringAttributes()) {
            throw new Exception("Can't handle string attributes!");
        }
        this.m_theInstances = instances;
        doEM();
        // Keep only the attribute information; the training instances are no longer needed.
        this.m_theInstances = new Instances(this.m_theInstances, 0);
    }

    /**
     * Returns the density of the instance under the mixture: the sum over clusters of the
     * prior-weighted cluster densities.
     *
     * @param instance the instance to evaluate
     * @return the mixture density
     * @throws Exception if the weights cannot be computed
     */
    @Override // weka.clusterers.DistributionClusterer
    public double densityForInstance(Instance instance) throws Exception {
        return Utils.sum(weightsForInstance(instance));
    }

    /**
     * Returns the normalised cluster membership distribution of the instance.
     *
     * @param instance the instance to evaluate
     * @return one probability per cluster
     * @throws Exception if the weights cannot be computed or normalised
     */
    @Override // weka.clusterers.DistributionClusterer
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = weightsForInstance(instance);
        Utils.normalize(distribution);
        return distribution;
    }

    /**
     * Computes, for every cluster, the cluster prior multiplied by the product of the
     * per-attribute densities of the instance. Missing values are skipped.
     *
     * @param instance the instance to evaluate
     * @return the unnormalised weight of the instance in each cluster
     * @throws Exception if an attribute value cannot be evaluated
     */
    protected double[] weightsForInstance(Instance instance) throws Exception {
        double[] weights = new double[this.m_num_clusters];
        for (int cluster = 0; cluster < this.m_num_clusters; cluster++) {
            double density = 1.0d;
            for (int att = 0; att < this.m_num_attribs; att++) {
                if (instance.isMissing(att)) {
                    continue;
                }
                if (instance.attribute(att).isNominal()) {
                    density *= this.m_model[cluster][att].getProbability(instance.value(att));
                } else {
                    double[] stats = this.m_modelNormal[cluster][att];
                    density *= normalDens(instance.value(att), stats[MEAN], stats[STD_DEV]);
                }
            }
            weights[cluster] = density * this.m_priors[cluster];
        }
        return weights;
    }

    /**
     * Runs the complete fitting procedure on the stored training data: seeds the random number
     * generator, determines the number of clusters if it was not given (cross validation when
     * there are more than nine instances, otherwise a single cluster), initialises the weights
     * and iterates EM.
     *
     * @throws Exception if fitting fails
     */
    private void doEM() throws Exception {
        if (this.m_verbose) {
            System.out.println("Seed: " + this.m_rseed);
        }
        this.m_rr = new Random(this.m_rseed);
        this.m_num_instances = this.m_theInstances.numInstances();
        this.m_num_attribs = this.m_theInstances.numAttributes();
        if (this.m_verbose) {
            System.out.println("Number of instances: " + this.m_num_instances
                + "\nNumber of atts: " + this.m_num_attribs + "\n");
        }
        if (this.m_initialNumClusters == -1) {
            this.m_num_clusters = (this.m_theInstances.numInstances() > 9) ? CVClusters() : 1;
        }
        EM_Init(this.m_theInstances, this.m_num_clusters);
        this.m_loglikely = iterate(this.m_theInstances, this.m_num_clusters, this.m_verbose);
    }

    /**
     * Alternates M and E steps until the maximum number of iterations is reached or, after the
     * first iteration, the value returned by the E step improves by less than 1e-6.
     *
     * @param instances the training data
     * @param i the number of clusters
     * @param z whether to print a report before and after iterating and the value of each step
     * @return the value returned by the last E step
     * @throws Exception if an M or E step fails
     */
    private double iterate(Instances instances, int i, boolean z) throws Exception {
        double current = 0.0d;
        if (z) {
            EM_Report(instances);
        }
        for (int iteration = 0; iteration < this.m_max_iterations; iteration++) {
            M(instances, i);
            double previous = current;
            current = E(instances, i);
            if (z) {
                System.out.println("Loglikely: " + current);
            }
            if (iteration > 0 && current - previous < CONVERGENCE_TOLERANCE) {
                break;
            }
        }
        if (z) {
            EM_Report(instances);
        }
        return current;
    }

    /**
     * Command-line entry point: evaluates an EM clusterer with the given options and prints
     * the result, or the error if evaluation fails.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        try {
            System.out.println(ClusterEvaluation.evaluateClusterer(new EM(), strArr));
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
        }
    }
}
