package org.encog.ml.bayesian.training.search.k2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.encog.mathutil.EncogMath;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.query.enumerate.EnumerationQuery;
import org.encog.ml.bayesian.training.TrainBayesian;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/* loaded from: classes.dex */
/**
 * K2 structure search (greedy parent selection) for a {@link BayesianNetwork}.
 *
 * <p>Nodes are visited in a fixed order: the classification target first, if the network has
 * one, then every other event in the order returned by {@code getEvents()}. For the node at
 * position {@code i}, the candidate parents are the nodes at positions {@code [0, i)}. The
 * candidate that raises the node's K2 score the most is added as a parent, repeatedly, until no
 * candidate raises the score or the trainer's maximum parent count is reached.
 *
 * <p>The search compares scores in log space. The raw score is a product of factorial ratios,
 * and a {@code double} can only represent n! up to n = 170, so for all but tiny data sets the
 * linear score degenerates to 0, infinity or NaN, after which every comparison fails and no
 * parents would ever be added.
 */
public class SearchK2 implements BayesSearch {
    private MLDataSet data;
    /** Log K2 score of the best candidate found by the most recent {@link #findZ} call. */
    private double lastCalculatedLogG;
    private BayesianNetwork network;
    private TrainBayesian train;
    /** Visiting order; a node may only take parents that precede it in this list. */
    private final List<BayesianEvent> nodeOrdering = new ArrayList<>();
    /** Position handled by the next iteration; -1 means the next call only rebuilds the order. */
    private int index = -1;
    /** Lazily grown table with {@code logFactorials[n] == ln(n!)}. */
    private double[] logFactorials = {0.0, 0.0};

    /**
     * Finds the predecessor of {@code event} whose addition as a parent gives the highest log K2
     * score strictly greater than {@code currentLogG}.
     *
     * @param event the node whose parent set is being extended
     * @param limit only nodes at positions {@code [0, limit)} of the ordering are candidates
     * @param currentLogG log K2 score of {@code event} with its current parents
     * @return the best candidate, or {@code null} if none improves the score; the best score
     *     (negative infinity when there is none) is stored in {@link #lastCalculatedLogG}
     */
    private BayesianEvent findZ(BayesianEvent event, int limit, double currentLogG) {
        BayesianEvent best = null;
        double bestLogG = Double.NEGATIVE_INFINITY;
        List<BayesianEvent> currentParents = event.getParents();
        for (int i = 0; i < limit; i++) {
            BayesianEvent candidate = this.nodeOrdering.get(i);
            // Re-adding an existing parent cannot genuinely improve the score; any apparent gain
            // would only be rounding noise from a different summation order.
            if (currentParents.contains(candidate)) {
                continue;
            }
            List<BayesianEvent> trialParents = new ArrayList<>(currentParents);
            trialParents.add(candidate);
            double logG = calculateLogG(this.network, event, trialParents);
            if (logG > currentLogG && logG > bestLogG) {
                best = candidate;
                bestLogG = logG;
            }
        }
        this.lastCalculatedLogG = bestLogG;
        return best;
    }

    /** Rebuilds {@link #nodeOrdering}: classification target first (if set), then all others. */
    private void orderNodes() {
        this.nodeOrdering.clear();
        if (this.network.getClassificationTarget() != -1) {
            this.nodeOrdering.add(this.network.getClassificationTargetEvent());
        }
        for (BayesianEvent event : this.network.getEvents()) {
            if (!this.nodeOrdering.contains(event)) {
                this.nodeOrdering.add(event);
            }
        }
    }

    /**
     * Computes the K2 score of {@code bayesianEvent} with parent set {@code list}:
     * the product over every parent-value combination j of
     * {@code (r-1)! / (N_j + r - 1)! * prod_k N_jk!}, where r is the number of choices of the
     * event, N_j the number of rows matching combination j and N_jk those rows whose event value
     * is choice k.
     *
     * <p>This is {@code exp(calculateLogG(...))}. It still underflows to 0 or overflows to
     * infinity for larger data sets (it no longer yields NaN), so prefer
     * {@link #calculateLogG} for comparisons.
     *
     * @param bayesianNetwork network used to resolve event indices
     * @param bayesianEvent the child event being scored
     * @param list candidate parents of {@code bayesianEvent}
     * @return the (linear) K2 score
     */
    public double calculateG(BayesianNetwork bayesianNetwork, BayesianEvent bayesianEvent,
            List<BayesianEvent> list) {
        return Math.exp(calculateLogG(bayesianNetwork, bayesianEvent, list));
    }

    /**
     * Computes the natural logarithm of the K2 score described in {@link #calculateG}, without
     * forming the overflow-prone factorials themselves.
     *
     * @param bayesianNetwork network used to resolve event indices
     * @param bayesianEvent the child event being scored
     * @param list candidate parents of {@code bayesianEvent}
     * @return ln of the K2 score
     */
    public double calculateLogG(BayesianNetwork bayesianNetwork, BayesianEvent bayesianEvent,
            List<BayesianEvent> list) {
        int choiceCount = bayesianEvent.getChoices().size();
        int eventIndex = bayesianNetwork.getEventIndex(bayesianEvent);
        int[] parentIndices = eventIndices(bayesianNetwork, list);
        List<int[]> rows = classifyRows();
        double logNumerator = logFactorial(choiceCount - 1);

        int[] args = new int[list.size()];
        double result = 0.0;
        do {
            // One pass yields both N_j (all matching rows) and N_jk (matching rows per choice).
            int total = 0;
            int[] perChoice = new int[choiceCount];
            for (int[] classes : rows) {
                if (!matchesParents(classes, parentIndices, args)) {
                    continue;
                }
                total++;
                int choice = classes[eventIndex];
                if (choice >= 0 && choice < choiceCount) {
                    perChoice[choice]++;
                }
            }
            result += logNumerator - logFactorial(total + choiceCount - 1);
            for (int count : perChoice) {
                result += logFactorial(count);
            }
        } while (EnumerationQuery.roll(list, args));
        return result;
    }

    /**
     * Counts training rows whose parent values equal {@code iArr}, regardless of the value of
     * {@code bayesianEvent} (which is not consulted).
     *
     * @param bayesianNetwork network used to resolve parent indices
     * @param bayesianEvent unused; kept for signature compatibility
     * @param list parent events, position-aligned with {@code iArr}
     * @param iArr required choice index for each parent
     * @return number of matching rows
     */
    public int calculateN(BayesianNetwork bayesianNetwork, BayesianEvent bayesianEvent,
            List<BayesianEvent> list, int[] iArr) {
        int[] parentIndices = eventIndices(bayesianNetwork, list);
        int count = 0;
        for (int[] classes : classifyRows()) {
            if (matchesParents(classes, parentIndices, iArr)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Counts training rows whose parent values equal {@code iArr} and whose value for
     * {@code bayesianEvent} is choice {@code i}.
     *
     * @param bayesianNetwork network used to resolve event indices
     * @param bayesianEvent the child event
     * @param list parent events, position-aligned with {@code iArr}
     * @param iArr required choice index for each parent
     * @param i required choice index of {@code bayesianEvent}
     * @return number of matching rows
     */
    public int calculateN(BayesianNetwork bayesianNetwork, BayesianEvent bayesianEvent,
            List<BayesianEvent> list, int[] iArr, int i) {
        int eventIndex = bayesianNetwork.getEventIndex(bayesianEvent);
        int[] parentIndices = eventIndices(bayesianNetwork, list);
        int count = 0;
        for (int[] classes : classifyRows()) {
            if (classes[eventIndex] == i && matchesParents(classes, parentIndices, iArr)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Maps every training row to its per-event choice indices. Rows are classified with this
     * search's own network; NOTE(review): callers of the public methods presumably pass that
     * same network — confirm.
     */
    private List<int[]> classifyRows() {
        List<int[]> rows = new ArrayList<>();
        Iterator<?> it = this.data.iterator();
        while (it.hasNext()) {
            rows.add(this.network.determineClasses(((MLDataPair) it.next()).getInput()));
        }
        return rows;
    }

    /** Resolves each event's index in {@code network}, preserving list order. */
    private static int[] eventIndices(BayesianNetwork network, List<BayesianEvent> events) {
        int[] indices = new int[events.size()];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = network.getEventIndex(events.get(i));
        }
        return indices;
    }

    /** True when {@code classes[parentIndices[i]] == args[i]} for every {@code i < args.length}. */
    private static boolean matchesParents(int[] classes, int[] parentIndices, int[] args) {
        for (int i = 0; i < args.length; i++) {
            if (args[i] != classes[parentIndices[i]]) {
                return false;
            }
        }
        return true;
    }

    /** Returns ln(n!), treating every {@code n <= 1} (including negatives) as ln(1) = 0. */
    private double logFactorial(int n) {
        if (n <= 1) {
            return 0.0;
        }
        if (n >= this.logFactorials.length) {
            int known = this.logFactorials.length;
            double[] grown = new double[Math.max(n + 1, 2 * known)];
            System.arraycopy(this.logFactorials, 0, grown, 0, known);
            for (int k = known; k < grown.length; k++) {
                grown[k] = grown[k - 1] + Math.log(k);
            }
            this.logFactorials = grown;
        }
        return this.logFactorials[n];
    }

    /**
     * Binds the search to a trainer, network and data set, and resets it so the next
     * {@link #iteration()} starts from scratch.
     */
    @Override // org.encog.ml.bayesian.training.search.k2.BayesSearch
    public void init(TrainBayesian trainBayesian, BayesianNetwork bayesianNetwork,
            MLDataSet mLDataSet) {
        this.network = bayesianNetwork;
        this.data = mLDataSet;
        this.train = trainBayesian;
        orderNodes();
        this.index = -1;
    }

    /**
     * Performs one search step. The first call only rebuilds the node ordering; each later call
     * greedily adds parents to the next node in the ordering.
     *
     * @return {@code true} while more nodes remain to be processed
     */
    @Override // org.encog.ml.bayesian.training.search.k2.BayesSearch
    public boolean iteration() {
        if (this.index == -1) {
            orderNodes();
        } else {
            BayesianEvent event = this.nodeOrdering.get(this.index);
            double currentLogG = calculateLogG(this.network, event, event.getParents());
            while (event.getParents().size() < this.train.getMaximumParents()) {
                BayesianEvent parent = findZ(event, this.index, currentLogG);
                if (parent == null) {
                    break;
                }
                this.network.createDependency(parent, event);
                currentLogG = this.lastCalculatedLogG;
            }
        }
        this.index++;
        // Also bound by the ordering's size so the next call can never index past its end.
        return this.index < Math.min(this.nodeOrdering.size(), this.data.getInputSize());
    }
}
