package org.encog.neural.pnn;

import java.util.Arrays;
import java.util.Iterator;
import org.encog.ml.MLClassification;
import org.encog.ml.MLError;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.NeuralNetworkError;
import org.encog.util.EngineArray;
import org.encog.util.simple.EncogUtility;

/* loaded from: classes.dex */
/**
 * A probabilistic neural network (PNN) that keeps its training samples in memory
 * and evaluates every query against all of them.
 *
 * <p>Behavior depends on {@link #getOutputMode()}:
 * <ul>
 * <li>Classification: the ideal value at index 0 of each sample is used as the
 * class index, and the output is a per-class score vector.</li>
 * <li>Regression: the output is a kernel-weighted mean of the samples' ideal vectors.</li>
 * <li>Unsupervised: the output is a kernel-weighted mean of the samples' input vectors.</li>
 * </ul>
 * Each input dimension is scaled by its own entry of {@link #getSigma()}.
 */
public class BasicPNN extends AbstractPNN implements MLRegression, MLError, MLClassification {

    private static final long serialVersionUID = -7990707837655024635L;

    /** Floor applied to kernel activations and normalizers so divisions never hit zero. */
    private static final double MIN_VALUE = 1.0E-40d;

    /** Number of stored samples per class; only populated in classification mode. */
    private int[] countPer;

    /**
     * Per-class priors; only populated in classification mode. A negative entry
     * means "no prior" and leaves that class's score unscaled.
     */
    private double[] priors;

    /** The stored samples every query is compared against. */
    private BasicMLDataSet samples;

    /** Per-input-dimension kernel width; {@link #getSigma()} returns this array itself. */
    private final double[] sigma;

    /**
     * Construct a PNN.
     *
     * @param kernel      the kernel type used to turn scaled distances into activations
     * @param outmodel    the output mode (classification, regression or unsupervised)
     * @param inputCount  the number of inputs; also the length of the sigma array
     * @param outputCount the number of outputs, passed to {@link AbstractPNN}
     */
    public BasicPNN(final PNNKernelType kernel, final PNNOutputMode outmodel,
            final int inputCount, final int outputCount) {
        super(kernel, outmodel, inputCount, outputCount);
        setSeparateClass(false);
        this.sigma = new double[inputCount];
    }

    /**
     * Calculate the error of this network over a data set.
     *
     * @param data the data set to evaluate
     * @return the classification error in classification mode, otherwise the regression error
     */
    @Override
    public double calculateError(final MLDataSet data) {
        if (getOutputMode() == PNNOutputMode.Classification) {
            return EncogUtility.calculateClassificationError(this, data);
        }
        return EncogUtility.calculateRegressionError(this, data);
    }

    /**
     * Classify an input vector.
     *
     * @param input the input vector
     * @return the index of the largest entry of {@link #compute(MLData)}
     */
    @Override
    public int classify(final MLData input) {
        return EngineArray.maxIndex(compute(input).getData());
    }

    /**
     * Evaluate the network for one input vector.
     *
     * <p>Every stored sample, except the one whose position equals
     * {@link #getExclude()}, is compared with {@code input}. Each per-dimension
     * difference is divided by the matching sigma entry, squared and summed. The
     * sum is then passed through the kernel: Gaussian gives {@code exp(-d)},
     * Reciprocal gives {@code 1 / (1 + d)}, and any other kernel uses {@code d}
     * unchanged. Activations below 1e-40 are raised to 1e-40.
     *
     * <ul>
     * <li>Classification: activations are summed per class. A class with a
     * non-negative prior and at least one sample is scaled by
     * {@code priors[c] / countPer[c]}. The scores are then divided by their total,
     * with the total floored at 1e-40.</li>
     * <li>Regression / Unsupervised: activation-weighted sums of the ideal / input
     * vectors are divided by the total activation, also floored at 1e-40. If no
     * sample contributed, the output is therefore all zeros rather than NaN.</li>
     * </ul>
     *
     * @param input the input vector; must have at least {@link #getInputCount()} entries
     * @return the network output, with {@link #getOutputCount()} entries
     * @throws NeuralNetworkError if no samples have been set
     */
    @Override
    public MLData compute(final MLData input) {
        if (this.samples == null) {
            throw new NeuralNetworkError("Samples must be set with setSamples() before calling compute().");
        }

        // Hoist loop-invariant configuration out of the per-sample loop.
        final PNNOutputMode mode = getOutputMode();
        final PNNKernelType kernel = getKernel();
        final int inputCount = getInputCount();
        final int outputCount = getOutputCount();
        final int exclude = getExclude();

        final double[] out = new double[outputCount];
        double weightSum = 0.0;

        int row = -1;
        for (final MLDataPair pair : this.samples) {
            row++;
            if (row == exclude) {
                continue;
            }

            final double activation = kernelActivation(kernel, input, pair.getInput(), inputCount);

            if (mode == PNNOutputMode.Classification) {
                final int cls = (int) pair.getIdeal().getData(0);
                out[cls] += activation;
            } else {
                if (mode == PNNOutputMode.Unsupervised) {
                    // NOTE(review): out has outputCount entries but this writes inputCount
                    // of them; assumes outputCount >= inputCount in unsupervised mode.
                    for (int i = 0; i < inputCount; i++) {
                        out[i] += activation * pair.getInput().getData(i);
                    }
                } else if (mode == PNNOutputMode.Regression) {
                    for (int i = 0; i < outputCount; i++) {
                        out[i] += activation * pair.getIdeal().getData(i);
                    }
                }
                weightSum += activation;
            }
        }

        if (mode == PNNOutputMode.Classification) {
            normalizeClassification(out, outputCount);
        } else if (mode == PNNOutputMode.Unsupervised) {
            divideAll(out, inputCount, floor(weightSum));
        } else if (mode == PNNOutputMode.Regression) {
            divideAll(out, outputCount, floor(weightSum));
        }
        return new BasicMLData(out);
    }

    /** Compute the floored kernel activation between a query and one stored sample. */
    private double kernelActivation(final PNNKernelType kernel, final MLData input,
            final MLData sample, final int inputCount) {
        double dist = 0.0;
        for (int i = 0; i < inputCount; i++) {
            final double diff = (input.getData(i) - sample.getData(i)) / this.sigma[i];
            dist += diff * diff;
        }

        final double activation;
        if (kernel == PNNKernelType.Gaussian) {
            activation = Math.exp(-dist);
        } else if (kernel == PNNKernelType.Reciprocal) {
            activation = 1.0 / (1.0 + dist);
        } else {
            activation = dist;
        }
        return floor(activation);
    }

    /** Apply priors and divide the per-class scores by their (floored) total. */
    private void normalizeClassification(final double[] out, final int outputCount) {
        double total = 0.0;
        for (int i = 0; i < outputCount; i++) {
            // Skip classes with no samples: dividing by a zero count would yield NaN
            // and poison every output through the shared total.
            if (this.priors[i] >= 0.0 && this.countPer[i] > 0) {
                out[i] *= this.priors[i] / this.countPer[i];
            }
            total += out[i];
        }
        divideAll(out, outputCount, floor(total));
    }

    /** Raise {@code value} to {@link #MIN_VALUE} if it is smaller (NaN passes through). */
    private static double floor(final double value) {
        return value < MIN_VALUE ? MIN_VALUE : value;
    }

    /** Divide the first {@code count} entries of {@code values} by {@code divisor}. */
    private static void divideAll(final double[] values, final int count, final double divisor) {
        for (int i = 0; i < count; i++) {
            values[i] /= divisor;
        }
    }

    /**
     * @return the per-class sample counts (the internal array, not a copy);
     *         {@code null} unless samples were set in classification mode
     */
    public int[] getCountPer() {
        return this.countPer;
    }

    /**
     * @return the per-class priors (the internal array, not a copy); negative
     *         entries mean "no prior"; {@code null} unless samples were set in
     *         classification mode
     */
    public double[] getPriors() {
        return this.priors;
    }

    /** @return the stored samples, or {@code null} if none have been set */
    public BasicMLDataSet getSamples() {
        return this.samples;
    }

    /** @return the per-input kernel widths (the internal array, not a copy) */
    public double[] getSigma() {
        return this.sigma;
    }

    /**
     * Set the samples the network evaluates against.
     *
     * <p>In classification mode this also counts samples per class and resets every
     * prior to -1 ("no prior"). The data is validated before any field is changed,
     * so a rejected data set leaves the network's previous state intact.
     *
     * @param samples the samples to store
     * @throws NeuralNetworkError in classification mode, if a sample's class index
     *         is negative or not less than {@link #getOutputCount()}
     */
    public void setSamples(final BasicMLDataSet samples) {
        if (getOutputMode() != PNNOutputMode.Classification) {
            this.samples = samples;
            return;
        }

        final int outputCount = getOutputCount();
        final int[] counts = new int[outputCount];
        for (final MLDataPair pair : samples) {
            final int cls = (int) pair.getIdeal().getData(0);
            if (cls >= outputCount) {
                throw new NeuralNetworkError("Training data contains more classes than neural network has output neurons to hold.");
            }
            if (cls < 0) {
                throw new NeuralNetworkError("Training data contains a negative class index: " + cls);
            }
            counts[cls]++;
        }

        final double[] noPriors = new double[outputCount];
        Arrays.fill(noPriors, -1.0d);

        // Commit only after validation has succeeded.
        this.samples = samples;
        this.countPer = counts;
        this.priors = noPriors;
    }

    /** No-op: this implementation does not derive any state from its properties. */
    @Override
    public void updateProperties() {
    }
}
