package org.encog.ensemble;

import b.a.a.a.a;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import org.encog.ensemble.EnsembleTypes;
import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.EnsembleDataSetFactory;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;

/* loaded from: classes.dex */
/**
 * Base class for an ensemble of machine-learning models.
 *
 * <p>Members are {@link EnsembleML} instances built with {@link #mlFactory}, trained on data sets
 * supplied by {@link #dataSetFactory} using trainers obtained from {@link #trainFactory}. The
 * members' outputs are combined by {@link #aggregator}. Subclasses decide how the member list is
 * populated by implementing {@link #initMembers()}, which the training-configuration setters call.
 */
public abstract class Ensemble {
    /** Default value passed to each member's training run as its iteration limit. */
    private static final int DEFAULT_MAX_ITERATIONS = 2000;

    /** Default number of retries allowed after a member's first failed training attempt. */
    private static final int DEFAULT_MAX_LOOPS = 2000;

    /** Combines the members' outputs into the ensemble output. */
    protected EnsembleAggregator aggregator;
    /**
     * Data used to train {@link #aggregator} when it needs training. Set by
     * {@link #initMembersBySplits(int)}; subclasses may also set it.
     */
    protected MLDataSet aggregatorDataSet;
    /** Supplies the input/output sizes and a fresh training data set for each member. */
    protected EnsembleDataSetFactory dataSetFactory;
    /**
     * The ensemble members, in insertion order. Not initialised in this class.
     * NOTE(review): subclasses presumably assign it before any member is added; confirm.
     */
    protected ArrayList<EnsembleML> members;
    /** Creates (and re-initialises) the underlying model of each member. */
    protected EnsembleMLMethodFactory mlFactory;
    /** Creates the trainer attached to each member. */
    protected EnsembleTrainFactory trainFactory;

    /**
     * Signals that an operation is not supported by a particular ensemble method. Not thrown by
     * this class itself.
     *
     * <p>NOTE(review): as a non-static inner class, every instance holds a reference to its
     * enclosing {@code Ensemble}, which hampers serialization. It is left non-static so that
     * existing {@code ensemble.new NotPossibleInThisMethod()} call sites keep compiling.
     */
    public class NotPossibleInThisMethod extends Exception {
        private static final long serialVersionUID = 5118253806179408868L;

        public NotPossibleInThisMethod() {
        }
    }

    /**
     * Thrown by {@link #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)}
     * when a member cannot reach the requested test error within the allowed attempts.
     *
     * <p>NOTE(review): non-static inner class (see {@link NotPossibleInThisMethod}); kept that way
     * for source compatibility with qualified {@code ensemble.new TrainingAborted(..)} callers.
     */
    public class TrainingAborted extends Exception {
        private static final long serialVersionUID = -5074472788684621859L;

        public TrainingAborted(String str) {
            super(str);
        }
    }

    /**
     * Appends an existing member to the ensemble.
     *
     * @param ensembleML the member to add
     */
    public void addMember(EnsembleML ensembleML) {
        this.members.add(ensembleML);
    }

    /** Builds a new member with {@link #generateNewMember()} and appends it to the ensemble. */
    public void addNewMember() {
        this.members.add(generateNewMember());
    }

    /**
     * Computes the ensemble output for one input.
     *
     * @param mLData the input presented to every member
     * @return the aggregator's combination of the members' outputs, given in member order
     */
    public MLData compute(MLData mLData) {
        ArrayList<MLData> outputs = new ArrayList<MLData>(this.members.size());
        for (EnsembleML member : this.members) {
            outputs.add(member.compute(mLData));
        }
        return this.aggregator.evaluate(outputs);
    }

    /**
     * Creates a new, untrained member: a fresh model from {@link #mlFactory}, a fresh training
     * set from {@link #dataSetFactory}, and a trainer from {@link #trainFactory}.
     * The member is not added to the ensemble.
     *
     * @return the new member
     */
    public EnsembleML generateNewMember() {
        return createMember();
    }

    /**
     * Shared construction logic for {@link #generateNewMember()} and
     * {@link #initMembersBySplits(int)}. It is private so that overriding
     * {@code generateNewMember} does not change how {@code initMembersBySplits} builds members.
     */
    private GenericEnsembleML createMember() {
        GenericEnsembleML member = new GenericEnsembleML(
                this.mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),
                this.mlFactory.getLabel());
        member.setTrainingSet(this.dataSetFactory.getNewDataSet());
        member.setTraining(this.trainFactory.getTraining(member.getMl(), member.getTrainingSet()));
        return member;
    }

    /** @return the aggregator that combines member outputs */
    public EnsembleAggregator getAggregator() {
        return this.aggregator;
    }

    /**
     * @param i index of the member
     * @return the member at {@code i}
     */
    public EnsembleML getMember(int i) {
        return this.members.get(i);
    }

    public abstract EnsembleTypes.ProblemType getProblemType();

    /**
     * @param i index of the member
     * @return the training set of the member at {@code i}
     */
    public MLDataSet getTrainingSet(int i) {
        return this.members.get(i).getTrainingSet();
    }

    /** Populates {@link #members}; called whenever the training configuration changes. */
    public abstract void initMembers();

    /**
     * Appends {@code i} new members, each with its own data set from {@link #dataSetFactory}.
     * If the aggregator needs training, also draws {@link #aggregatorDataSet} from the factory.
     * Does nothing if there is no data-set factory, if {@code i <= 0}, or if the factory has no
     * source data.
     *
     * @param i number of members to create
     */
    public void initMembersBySplits(int i) {
        EnsembleDataSetFactory factory = this.dataSetFactory;
        if (factory == null || i <= 0 || !factory.hasSource()) {
            return;
        }
        for (int n = 0; n < i; n++) {
            this.members.add(createMember());
        }
        if (this.aggregator.needsTraining()) {
            this.aggregatorDataSet = this.dataSetFactory.getNewDataSet();
        }
    }

    /**
     * Trains the aggregator on the members' outputs for {@link #aggregatorDataSet}.
     *
     * <p>For each pair, the outputs of all members are concatenated in member order into one input
     * vector of length {@code idealSize * members.size()}, which is paired with the original ideal.
     * This layout assumes every member outputs exactly {@code idealSize} values.
     *
     * @throws IllegalStateException if {@link #aggregatorDataSet} has not been set
     */
    public void retrainAggregator() {
        if (this.aggregatorDataSet == null) {
            throw new IllegalStateException("aggregatorDataSet is not set; cannot retrain the aggregator");
        }
        int idealSize = this.aggregatorDataSet.getIdealSize();
        int inputSize = idealSize * this.members.size();
        EnsembleDataSet aggregatorTrainingSet = new EnsembleDataSet(inputSize, idealSize);
        for (MLDataPair pair : this.aggregatorDataSet) {
            BasicMLData combined = new BasicMLData(inputSize);
            int offset = 0;
            for (EnsembleML member : this.members) {
                for (double value : member.compute(pair.getInput()).getData()) {
                    // The vector starts zeroed, so add() effectively stores the value.
                    combined.add(offset++, value);
                }
            }
            aggregatorTrainingSet.add(combined, pair.getIdeal());
        }
        this.aggregator.setTrainingSet(aggregatorTrainingSet);
        this.aggregator.train();
    }

    public void setAggregator(EnsembleAggregator ensembleAggregator) {
        this.aggregator = ensembleAggregator;
    }

    /** Feeds new source data to the data-set factory and rebuilds the members. */
    public void setTrainingData(MLDataSet mLDataSet) {
        this.dataSetFactory.setInputData(mLDataSet);
        initMembers();
    }

    /** Replaces the data-set factory and rebuilds the members. */
    public void setTrainingDataFactory(EnsembleDataSetFactory ensembleDataSetFactory) {
        this.dataSetFactory = ensembleDataSetFactory;
        initMembers();
    }

    /** Replaces the trainer factory and rebuilds the members. */
    public void setTrainingMethod(EnsembleTrainFactory ensembleTrainFactory) {
        this.trainFactory = ensembleTrainFactory;
        initMembers();
    }

    /**
     * Trains every member (see
     * {@link #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)}),
     * then retrains the aggregator if it needs training.
     *
     * @throws TrainingAborted if any member fails to reach {@code selectionError}
     */
    public void train(double targetError, double selectionError, int maxIterations, int maxLoops,
            EnsembleDataSet testSet, boolean verbose) throws TrainingAborted {
        for (EnsembleML member : this.members) {
            trainMember(member, targetError, selectionError, maxIterations, maxLoops, testSet, verbose);
        }
        if (this.aggregator.needsTraining()) {
            retrainAggregator();
        }
    }

    /** Trains all members quietly, with the default retry budget. */
    public void train(double targetError, double selectionError, int maxIterations, EnsembleDataSet testSet)
            throws TrainingAborted {
        train(targetError, selectionError, maxIterations, DEFAULT_MAX_LOOPS, testSet, false);
    }

    /** Trains all members quietly, with the default iteration and retry limits. */
    public void train(double targetError, double selectionError, EnsembleDataSet testSet) throws TrainingAborted {
        train(targetError, selectionError, testSet, false);
    }

    /** Trains all members with the default iteration and retry limits. */
    public void train(double targetError, double selectionError, EnsembleDataSet testSet, boolean verbose)
            throws TrainingAborted {
        train(targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_LOOPS, testSet, verbose);
    }

    /** Trains the member at {@code i} with the default retry budget. */
    public void trainMember(int i, double targetError, double selectionError, int maxIterations,
            EnsembleDataSet testSet, boolean verbose) throws TrainingAborted {
        trainMember(this.members.get(i), targetError, selectionError, maxIterations, DEFAULT_MAX_LOOPS, testSet, verbose);
    }

    /** Trains the member at {@code i} with the default iteration and retry limits. */
    public void trainMember(int i, double targetError, double selectionError, EnsembleDataSet testSet,
            boolean verbose) throws TrainingAborted {
        trainMember(i, targetError, selectionError, DEFAULT_MAX_ITERATIONS, testSet, verbose);
    }

    /**
     * Repeatedly re-initialises and trains one member until its error on {@code testSet} is at
     * most {@code selectionError}.
     *
     * <p>Each attempt calls {@code mlFactory.reInit} on the member's model and then trains it. At
     * most {@code maxLoops + 1} attempts are made, and the result of every attempt, including the
     * last, is checked before giving up. A {@code NaN} test error is never treated as success.
     *
     * @param ensembleML     the member to train
     * @param targetError    first argument of {@code EnsembleML.train} (presumably the target
     *                       training error; confirm)
     * @param selectionError largest acceptable error on {@code testSet}
     * @param maxIterations  second argument of {@code EnsembleML.train} (presumably an iteration
     *                       limit; confirm)
     * @param maxLoops       number of retries allowed after the first failed attempt
     * @param testSet        data used to evaluate each attempt
     * @param verbose        if true, prints timing and test error to {@code System.out}; also
     *                       passed to {@code EnsembleML.train}
     * @throws TrainingAborted if no attempt reaches {@code selectionError}
     */
    public void trainMember(EnsembleML ensembleML, double targetError, double selectionError, int maxIterations,
            int maxLoops, EnsembleDataSet testSet, boolean verbose) throws TrainingAborted {
        int attempts = 0;
        while (true) {
            long start = System.nanoTime();
            this.mlFactory.reInit(ensembleML.getMl());
            ensembleML.train(targetError, maxIterations, verbose);
            long stop = System.nanoTime();
            attempts++;
            if (verbose) {
                System.out.println("training took " + ((stop - start) / 1.0E9d));
            }
            // Evaluate once per attempt and reuse the value for both logging and the stop test.
            double testError = ensembleML.getError(testSet);
            if (verbose) {
                System.out.println("test MSE: " + testError + " on " + testSet.size() + " data points");
            }
            // Written as "<=" so that a NaN error (diverged model) counts as failure.
            if (testError <= selectionError) {
                return;
            }
            if (attempts > maxLoops) {
                throw new TrainingAborted("Too many attempts at training ensemble member");
            }
        }
    }

    /** Trains one member with the default iteration and retry limits. */
    public void trainMember(EnsembleML ensembleML, double targetError, double selectionError,
            EnsembleDataSet testSet, boolean verbose) throws TrainingAborted {
        trainMember(ensembleML, targetError, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_LOOPS, testSet, verbose);
    }
}
