package org.encog.neural.networks.training.propagation.sgd;

import java.util.Arrays;
import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.CrossEntropyErrorFunction;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.Momentum;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.sgd.update.AdamUpdate;
import org.encog.neural.networks.training.propagation.sgd.update.UpdateRule;

/* loaded from: classes.dex */
/**
 * Mini-batch stochastic gradient descent (SGD) trainer for networks backed by a
 * {@link FlatNetwork}.
 *
 * <p>{@link #process(MLDataPair)} runs one pair forward, seeds the deltas with the configured
 * {@link ErrorFunction} (a {@link CrossEntropyErrorFunction}), and back-propagates them while
 * accumulating weight gradients. {@link #update()} then hands the accumulated gradients to the
 * {@link UpdateRule} ({@link AdamUpdate} by default), records the error accumulated since the
 * previous update and clears the accumulators. {@link #iteration()} does both steps for every pair
 * the training set currently exposes.
 *
 * <p>A training set that is not already a {@link BatchDataSet} is wrapped in one with a batch size
 * of 25. Pausing and resuming training is not supported.
 */
public class StochasticGradientDescent extends BasicTraining implements Momentum, LearningRate {
    /** L1/L2 strengths at or below this value are treated as "regularization disabled". */
    private static final double REGULARIZATION_THRESHOLD = 1.0E-13d;
    /** Batch size used when the caller supplies a plain (non-batch) training set. */
    private static final int DEFAULT_BATCH_SIZE = 25;
    /** Initial value returned by {@link #getLearningRate()}. */
    private static final double DEFAULT_LEARNING_RATE = 0.001d;
    /** Initial value returned by {@link #getMomentum()}. */
    private static final double DEFAULT_MOMENTUM = 0.9d;

    /** Accumulates the error of every pair processed since the last {@link #update()}. */
    private final ErrorCalculation errorCalculation;
    /** Computes the deltas that seed back-propagation in {@link #process(MLDataPair)}. */
    private final ErrorFunction errorFunction;
    /** The flat representation of the network being trained. */
    private final FlatNetwork flat;
    /** Gradient accumulator with one entry per network weight; cleared after each update. */
    private final double[] gradients;
    /** L1 regularization strength (0 = disabled). */
    private double l1;
    /** L2 regularization strength (0 = disabled). */
    private double l2;
    /** Per-neuron deltas, one entry per element of the network's layer-output array. */
    private final double[] layerDelta;
    /** Stored learning rate; this class only stores it, presumably the update rule reads it. */
    private double learningRate;
    /** The method (network) returned by {@link #getMethod()}. */
    private final MLMethod method;
    /** Stored momentum; this class only stores it, presumably the update rule reads it. */
    private double momentum;
    /** Random source handed to any {@link BatchDataSet} created by {@link #setBatchSize(int)}. */
    private final GenerateRandom rnd;
    /** Strategy that turns accumulated gradients into weight changes. */
    private UpdateRule updateRule;

    /**
     * Creates a trainer that uses a new {@link MersenneTwisterGenerateRandom} as its random source.
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data; wrapped in a {@link BatchDataSet} if it is not one
     */
    public StochasticGradientDescent(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        this(containsFlat, mLDataSet, new MersenneTwisterGenerateRandom());
    }

    /**
     * Creates a trainer with an explicit random source.
     *
     * @param containsFlat the network to train
     * @param mLDataSet the training data; wrapped in a {@link BatchDataSet} (batch size 25) if it
     *     is not one already
     * @param generateRandom random source handed to any {@link BatchDataSet} this trainer creates
     */
    public StochasticGradientDescent(ContainsFlat containsFlat, MLDataSet mLDataSet, GenerateRandom generateRandom) {
        super(TrainingImplementationType.Iterative);
        // Must be assigned before setBatchSize() below, which passes it to the BatchDataSet it
        // creates; assigning it afterwards left that BatchDataSet with a null generator.
        this.rnd = generateRandom;
        this.updateRule = new AdamUpdate();
        this.errorFunction = new CrossEntropyErrorFunction();
        setTraining(mLDataSet);
        if (!(mLDataSet instanceof BatchDataSet)) {
            setBatchSize(DEFAULT_BATCH_SIZE);
        }
        this.method = containsFlat;
        this.flat = containsFlat.getFlat();
        this.layerDelta = new double[this.flat.getLayerOutput().length];
        this.gradients = new double[this.flat.getWeights().length];
        this.errorCalculation = new ErrorCalculation();
        this.learningRate = DEFAULT_LEARNING_RATE;
        this.momentum = DEFAULT_MOMENTUM;
    }

    /**
     * Back-propagates the known deltas of layer {@code currentLevel} into layer
     * {@code currentLevel + 1}.
     *
     * <p>For every neuron of layer {@code currentLevel + 1} this adds
     * {@code output * delta(target)} to the gradient of each connecting weight. It also sets that
     * neuron's own delta to the activation derivative times the weighted sum of the target deltas.
     *
     * @param currentLevel index of the layer whose deltas are already known
     */
    private void processLevel(final int currentLevel) {
        final FlatNetwork network = this.flat;
        final int fromLayerIndex = network.getLayerIndex()[currentLevel + 1];
        final int toLayerIndex = network.getLayerIndex()[currentLevel];
        final int fromLayerSize = network.getLayerCounts()[currentLevel + 1];
        final int toLayerSize = network.getLayerFeedCounts()[currentLevel];
        final int weightIndex = network.getWeightIndex()[currentLevel];
        final ActivationFunction activation = network.getActivationFunctions()[currentLevel];

        // Method-local array references avoid repeated getter/field indirection in the hot loop.
        final double[] deltas = this.layerDelta;
        final double[] weights = network.getWeights();
        final double[] grads = this.gradients;
        final double[] layerOutput = network.getLayerOutput();
        final double[] layerSums = network.getLayerSums();
        final int toLayerEnd = toLayerIndex + toLayerSize;

        for (int y = 0; y < fromLayerSize; y++) {
            final int yi = fromLayerIndex + y;
            final double output = layerOutput[yi];
            double sum = 0.0d;
            // The weight linking source y to target xi sits at
            // weightIndex + (xi - toLayerIndex) * fromLayerSize + y.
            int wi = weightIndex + y;
            for (int xi = toLayerIndex; xi < toLayerEnd; xi++, wi += fromLayerSize) {
                grads[wi] += deltas[xi] * output;
                sum += weights[wi] * deltas[xi];
            }
            deltas[yi] = activation.derivativeFunction(layerSums[yi], layerOutput[yi]) * sum;
        }
    }

    /**
     * Accumulates L1 and L2 weight penalties over every weight layer of the network.
     *
     * @param dArr two-element accumulator: {@code dArr[0]} receives the sum of absolute weights
     *     and {@code dArr[1]} the sum of squared weights; existing values are added to, not reset
     */
    public void calculateRegularizationPenalty(double[] dArr) {
        final int weightLayers = this.flat.getLayerCounts().length - 1;
        for (int layer = 0; layer < weightLayers; layer++) {
            layerRegularizationPenalty(layer, dArr);
        }
    }

    /**
     * Always {@code false}: this trainer cannot be paused and resumed.
     *
     * @return {@code false}
     */
    @Override // org.encog.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /**
     * Returns the batch size of the training set.
     *
     * @return the batch size, or 0 when the current training set is not a {@link BatchDataSet}
     */
    public int getBatchSize() {
        if (getTraining() instanceof BatchDataSet) {
            return ((BatchDataSet) getTraining()).getBatchSize();
        }
        return 0;
    }

    /** @return the flat network being trained */
    public FlatNetwork getFlat() {
        return this.flat;
    }

    /** @return the L1 regularization strength */
    public double getL1() {
        return this.l1;
    }

    /** @return the L2 regularization strength */
    public double getL2() {
        return this.l2;
    }

    /** @return the stored learning rate */
    @Override // org.encog.neural.networks.training.LearningRate
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the network passed to the constructor */
    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.method;
    }

    /** @return the stored momentum */
    @Override // org.encog.neural.networks.training.Momentum
    public double getMomentum() {
        return this.momentum;
    }

    /** @return the rule used to apply gradients to the weights */
    public UpdateRule getUpdateRule() {
        return this.updateRule;
    }

    /**
     * Always {@code false}: resuming is not supported.
     *
     * @param trainingContinuation ignored
     * @return {@code false}
     */
    public boolean isValidResume(TrainingContinuation trainingContinuation) {
        return false;
    }

    /**
     * Performs one training iteration. It back-propagates every pair the training set currently
     * exposes and then applies a single weight update via {@link #update()}.
     *
     * <p>{@link #update()} already runs the whole iteration lifecycle. That covers update-rule
     * initialisation on the first iteration, {@code preIteration}/{@code postIteration}, and
     * advancing a {@link BatchDataSet}. Repeating those steps here would count every iteration
     * twice and skip every other batch.
     */
    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        final MLDataSet training = getTraining();
        for (int i = 0; i < training.size(); i++) {
            process(training.get(i));
        }
        update();
    }

    /**
     * Adds the L1 (sum of |w|) and L2 (sum of w^2) penalties of one weight layer to an
     * accumulator.
     *
     * @param i index of the weight layer (between layer {@code i} and {@code i + 1})
     * @param dArr two-element accumulator; index 0 gets the L1 sum, index 1 the L2 sum
     */
    public void layerRegularizationPenalty(int i, double[] dArr) {
        final int fromCount = this.flat.getLayerTotalNeuronCount(i);
        final int toCount = this.flat.getLayerNeuronCount(i + 1);
        for (int from = 0; from < fromCount; from++) {
            for (int to = 0; to < toCount; to++) {
                final double weight = this.flat.getWeight(i, from, to);
                dArr[0] += Math.abs(weight);
                dArr[1] += weight * weight;
            }
        }
    }

    /**
     * Pausing is not supported.
     *
     * @return always {@code null}
     */
    @Override // org.encog.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /** Delegates to {@link BasicTraining#preIteration()}. */
    @Override // org.encog.ml.train.BasicTraining
    public void preIteration() {
        super.preIteration();
    }

    /**
     * Runs one training pair through the network and accumulates its weight gradients and error.
     * Weights are not changed until {@link #update()} is called.
     *
     * @param mLDataPair the input/ideal pair to learn from
     */
    public void process(MLDataPair mLDataPair) {
        final double[] actual = new double[this.flat.getOutputCount()];
        this.flat.compute(mLDataPair.getInputArray(), actual);

        // Accumulate into the shared calculator. update() reports and resets it once per batch,
        // so it must not be recreated here or only the last pair's error would be reported.
        this.errorCalculation.updateError(actual, mLDataPair.getIdealArray(), mLDataPair.getSignificance());

        // Seed the deltas that back-propagation starts from, using layer 0's activation function.
        this.errorFunction.calculateError(this.flat.getActivationFunctions()[0], this.flat.getLayerSums(),
                this.flat.getLayerOutput(), mLDataPair.getIdeal().getData(), actual, this.layerDelta, 0.0d,
                mLDataPair.getSignificance());

        if (this.l1 > REGULARIZATION_THRESHOLD || this.l2 > REGULARIZATION_THRESHOLD) {
            final double[] penalties = new double[2];
            calculateRegularizationPenalty(penalties);
            // The penalty is a single scalar, so compute it once rather than per output.
            // NOTE(review): this adds the same scalar to every output delta rather than a
            // per-weight penalty gradient; behavior kept as-is, confirm this is intended.
            final double penalty = (penalties[1] * this.l2) + (penalties[0] * this.l1);
            for (int i = 0; i < actual.length; i++) {
                this.layerDelta[i] += penalty;
            }
        }

        for (int level = this.flat.getBeginTraining(); level < this.flat.getEndTraining(); level++) {
            processLevel(level);
        }
    }

    /** Clears the error accumulated since the last update. */
    public void resetError() {
        this.errorCalculation.reset();
    }

    /**
     * Resuming is not supported.
     *
     * @param trainingContinuation ignored
     * @throws EncogError always
     */
    @Override // org.encog.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
        throw new EncogError("Resume not currently supported.");
    }

    /**
     * Sets the batch size. If the current training set is not a {@link BatchDataSet}, it is first
     * wrapped in one that uses this trainer's random source.
     *
     * @param i the new batch size
     */
    public void setBatchSize(int i) {
        if (getTraining() instanceof BatchDataSet) {
            ((BatchDataSet) getTraining()).setBatchSize(i);
            return;
        }
        final BatchDataSet batchDataSet = new BatchDataSet(getTraining(), this.rnd);
        batchDataSet.setBatchSize(i);
        setTraining(batchDataSet);
    }

    /** @param d2 the L1 regularization strength; values at or below 1e-13 disable it */
    public void setL1(double d2) {
        this.l1 = d2;
    }

    /** @param d2 the L2 regularization strength; values at or below 1e-13 disable it */
    public void setL2(double d2) {
        this.l2 = d2;
    }

    /** @param d2 the learning rate to store */
    @Override // org.encog.neural.networks.training.LearningRate
    public void setLearningRate(double d2) {
        this.learningRate = d2;
    }

    /** @param d2 the momentum to store */
    @Override // org.encog.neural.networks.training.Momentum
    public void setMomentum(double d2) {
        this.momentum = d2;
    }

    /** @param updateRule the rule used to apply gradients to the weights */
    public void setUpdateRule(UpdateRule updateRule) {
        this.updateRule = updateRule;
    }

    /**
     * Applies the gradients accumulated by {@link #process(MLDataPair)} to the network weights and
     * completes one training iteration.
     *
     * <p>On the first iteration the update rule is initialised. The error accumulated since the
     * previous update is recorded via {@code setError}. Gradients and error are then cleared, and
     * a {@link BatchDataSet} is advanced to its next batch.
     */
    public void update() {
        if (getIteration() == 0) {
            this.updateRule.init(this);
        }
        preIteration();
        this.updateRule.update(this.gradients, this.flat.getWeights());
        setError(this.errorCalculation.calculate());
        postIteration();

        // Start the next batch with clean accumulators.
        Arrays.fill(this.gradients, 0.0d);
        this.errorCalculation.reset();
        if (getTraining() instanceof BatchDataSet) {
            ((BatchDataSet) getTraining()).advance();
        }
    }
}
