package org.encog.neural.networks.training.propagation;

import java.util.Arrays;
import java.util.Random;
import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.IntRange;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.BatchSize;
import org.encog.neural.networks.training.Train;
import org.encog.util.EncogValidate;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.concurrency.TaskGroup;
import org.encog.util.logging.EncogLogging;

/* loaded from: classes.dex */
public abstract class Propagation extends BasicTraining implements Train, MultiThreadable, BatchSize, GradientWorkerOwner {
    /** Records per weight update; 0 selects a single full ("pure") batch per iteration. */
    private int batchSize;
    /** The flat network whose weights are trained in place. */
    private FlatNetwork currentFlatNetwork;
    /** Random source exposed to subclasses for dropout. */
    protected Random dropoutRandomSource;
    /** Dropout rate; when > 0 the dropout overload of updateWeight is used. */
    private double dropoutRate;
    /** Error function handed to every gradient worker. */
    private ErrorFunction ef;
    /** Set once finishTraining has scaled the weights, so scaling happens only once. */
    private boolean finalized;
    /** Per-activation-function flat-spot offsets shared with the workers. */
    private double[] flatSpot;
    /** Gradients accumulated from worker reports; cleared after each weight update. */
    protected double[] gradients;
    /** Training data used to build the per-worker data views. */
    private final MLDataSet indexable;
    /** Private counter advanced by rollIteration(). */
    private int iteration;
    private double l1;
    private double l2;
    /** Previous gradients, passed to updateWeight for momentum-style update rules. */
    private final double[] lastGradient;
    /** The network being trained. */
    protected final ContainsFlat network;
    /** Requested worker-thread count (forced to 1 in mini-batch mode by init()). */
    private int numThreads;
    /** First failure reported by a worker; rethrown by iteration(int). */
    private Throwable reportedException;
    /** Whether sigmoid layers receive the 0.1 flat-spot offset. */
    private boolean shouldFixFlatSpot;
    /** Sum of the errors reported by workers during calculateGradients(). */
    private double totalError;
    /** Gradient workers; created lazily on first use. */
    private GradientWorker[] workers;

    /**
     * Creates a propagation trainer for the given network and training set.
     *
     * @param containsFlat the network to train; its flat network is trained in place
     * @param mLDataSet the training data
     */
    public Propagation(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        super(TrainingImplementationType.Iterative);
        this.dropoutRandomSource = new Random();
        this.dropoutRate = 0.0d;
        this.ef = new LinearErrorFunction();
        this.batchSize = 0;
        this.finalized = false;
        this.network = containsFlat;
        this.currentFlatNetwork = containsFlat.getFlat();
        setTraining(mLDataSet);
        int weightCount = this.currentFlatNetwork.getWeights().length;
        this.gradients = new double[weightCount];
        this.lastGradient = new double[weightCount];
        this.indexable = mLDataSet;
        this.numThreads = 0;
        this.reportedException = null;
        this.shouldFixFlatSpot = true;
    }

    /**
     * Propagates recurrent context between workers.
     *
     * Each worker's layer output is copied into the next worker's network. The last
     * worker's output is copied into the trained network.
     */
    private void copyContexts() {
        for (int w = 0; w < this.workers.length - 1; w++) {
            EngineArray.arrayCopy(this.workers[w].getNetwork().getLayerOutput(),
                    this.workers[w + 1].getNetwork().getLayerOutput());
        }
        GradientWorker last = this.workers[this.workers.length - 1];
        EngineArray.arrayCopy(last.getNetwork().getLayerOutput(), this.currentFlatNetwork.getLayerOutput());
    }

    /**
     * Lazily builds the flat-spot table and the gradient workers, then calls initOthers().
     */
    private void init() {
        int functionCount = this.currentFlatNetwork.getActivationFunctions().length;
        this.flatSpot = new double[functionCount];
        if (this.shouldFixFlatSpot) {
            for (int f = 0; f < functionCount; f++) {
                boolean sigmoid = this.currentFlatNetwork.getActivationFunctions()[f] instanceof ActivationSigmoid;
                this.flatSpot[f] = sigmoid ? 0.1d : 0.0d;
            }
        } else {
            Arrays.fill(this.flatSpot, 0.0d);
        }
        if (this.batchSize != 0) {
            // Mini-batch training is driven sequentially through workers[0] only.
            this.numThreads = 1;
        }
        DetermineWorkload workload = new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new GradientWorker[workload.getThreadCount()];
        int slot = 0;
        for (IntRange range : workload.calculateWorkers()) {
            this.workers[slot++] = new GradientWorker(this.currentFlatNetwork.clone(), this,
                    this.indexable.openAdditional(), range.getLow(), range.getHigh(), this.flatSpot, this.ef);
        }
        initOthers();
    }

    /**
     * Applies the accumulated gradients.
     *
     * Uses learnLimited() when the network has a connection limit, otherwise learn().
     */
    private void applyGradients() {
        if (this.currentFlatNetwork.isLimited()) {
            learnLimited();
        } else {
            learn();
        }
    }

    /**
     * Mini-batch training: runs workers[0] over every record and updates the weights
     * after each batchSize records. A final partial batch is flushed at the end.
     */
    private void processBatches() {
        if (this.workers == null) {
            init();
        }
        GradientWorker worker = this.workers[0];
        if (this.currentFlatNetwork.getHasContext()) {
            worker.getNetwork().clearContext();
        }
        worker.getErrorCalculation().reset();
        int sinceLastUpdate = 0;
        int recordCount = getTraining().size();
        for (int record = 0; record < recordCount; record++) {
            worker.run(record);
            sinceLastUpdate++;
            if (sinceLastUpdate >= this.batchSize) {
                applyGradients();
                sinceLastUpdate = 0;
            }
        }
        if (sinceLastUpdate > 0) {
            // Flush the trailing partial batch with the same update rule as full batches.
            applyGradients();
        }
        setError(worker.getErrorCalculation().calculate());
    }

    /** Full-batch training: gradients over the whole training set, then one weight update. */
    private void processPureBatch() {
        calculateGradients();
        applyGradients();
    }

    /**
     * Runs every gradient worker once and accumulates their gradients via report().
     *
     * Workers run in a TaskGroup when there is more than one. The error is set to the
     * reported total divided by the worker count. NOTE(review): this is an unweighted mean
     * over workers; confirm that is intended when workers process different record counts.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            init();
        }
        if (this.currentFlatNetwork.getHasContext()) {
            this.workers[0].getNetwork().clearContext();
        }
        this.totalError = 0.0d;
        if (this.workers.length > 1) {
            TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (GradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        setError(this.totalError / this.workers.length);
    }

    /** Finishes training using the configured dropout rate. */
    @Override // org.encog.ml.train.BasicTraining, org.encog.ml.train.MLTrain
    public void finishTraining() {
        finishTraining(this.dropoutRate);
    }

    /**
     * Scales every weight by (1 - rate) when rate > 0. Only the first call has any effect.
     *
     * @param d2 the dropout rate used during training
     */
    public void finishTraining(double d2) {
        if (this.finalized) {
            return;
        }
        if (d2 > 0.0d) {
            double keep = 1.0d - d2;
            double[] weights = this.currentFlatNetwork.getWeights();
            for (int w = 0; w < weights.length; w++) {
                weights[w] = keep * weights[w];
            }
        }
        this.finalized = true;
    }

    /**
     * Enables or disables the sigmoid flat-spot fix. Only takes effect before the workers
     * are created.
     *
     * @param z true to apply the flat-spot offset to sigmoid layers
     */
    public void fixFlatSpot(boolean z) {
        this.shouldFixFlatSpot = z;
    }

    @Override // org.encog.neural.networks.training.BatchSize
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @return the flat network being trained */
    public FlatNetwork getCurrentFlatNetwork() {
        return this.currentFlatNetwork;
    }

    /** @return the dropout rate */
    public double getDropoutRate() {
        return this.dropoutRate;
    }

    @Override // org.encog.neural.networks.training.propagation.GradientWorkerOwner
    public double getL1() {
        return this.l1;
    }

    @Override // org.encog.neural.networks.training.propagation.GradientWorkerOwner
    public double getL2() {
        return this.l2;
    }

    /** @return the live previous-gradient array (not a copy) */
    public double[] getLastGradient() {
        return this.lastGradient;
    }

    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    @Override // org.encog.util.concurrency.MultiThreadable
    public int getThreadCount() {
        return this.numThreads;
    }

    /** Hook for subclass-specific setup; called at the end of the lazy worker initialization. */
    public abstract void initOthers();

    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        iteration(1);
    }

    /**
     * Performs the given number of training iterations.
     *
     * Each iteration trains (pure batch or mini-batch), pushes the updated weights to every
     * worker, propagates recurrent context, and rethrows any worker failure.
     *
     * @param i the number of iterations to run
     * @throws EncogError if a worker reported an exception or an array index went out of bounds
     */
    @Override // org.encog.ml.train.BasicTraining, org.encog.ml.train.MLTrain
    public void iteration(int i) {
        for (int pass = 0; pass < i; pass++) {
            try {
                preIteration();
                rollIteration();
                if (this.batchSize == 0) {
                    processPureBatch();
                } else {
                    processBatches();
                }
                double[] masterWeights = this.currentFlatNetwork.getWeights();
                for (GradientWorker worker : this.workers) {
                    System.arraycopy(masterWeights, 0, worker.getWeights(), 0, masterWeights.length);
                }
                if (this.currentFlatNetwork.getHasContext()) {
                    copyContexts();
                }
                if (this.reportedException != null) {
                    throw new EncogError(this.reportedException);
                }
                postIteration();
                EncogLogging.log(1, "Training iteration done, error: " + getError());
            } catch (ArrayIndexOutOfBoundsException e) {
                // Give validation a chance to report a more descriptive problem first.
                EncogValidate.validateNetworkForTraining(this.network, getTraining());
                throw new EncogError(e);
            }
        }
    }

    /**
     * Adds updateWeight(...) to every weight and clears the accumulated gradients.
     *
     * The dropout overload of updateWeight is used when the dropout rate is positive.
     */
    protected void learn() {
        double[] weights = this.currentFlatNetwork.getWeights();
        boolean dropout = this.dropoutRate > 0.0d;
        for (int w = 0; w < this.gradients.length; w++) {
            double delta = dropout
                    ? updateWeight(this.gradients, this.lastGradient, w, this.dropoutRate)
                    : updateWeight(this.gradients, this.lastGradient, w);
            weights[w] += delta;
            this.gradients[w] = 0.0d;
        }
    }

    /**
     * Like learn(), but weights whose magnitude is below the connection limit are zeroed
     * instead of updated.
     */
    protected void learnLimited() {
        double limit = this.currentFlatNetwork.getConnectionLimit();
        double[] weights = this.currentFlatNetwork.getWeights();
        boolean dropout = this.dropoutRate > 0.0d;
        for (int w = 0; w < this.gradients.length; w++) {
            if (Math.abs(weights[w]) < limit) {
                weights[w] = 0.0d;
            } else {
                double delta = dropout
                        ? updateWeight(this.gradients, this.lastGradient, w, this.dropoutRate)
                        : updateWeight(this.gradients, this.lastGradient, w);
                weights[w] += delta;
            }
            this.gradients[w] = 0.0d;
        }
    }

    /**
     * Receives a worker's result under this object's monitor.
     *
     * On success the gradients and error are accumulated. Otherwise the failure is recorded
     * for iteration(int) to rethrow.
     *
     * @param dArr the worker's gradients (used only when th is null)
     * @param d2 the worker's error
     * @param th the worker's failure, or null on success
     */
    @Override // org.encog.neural.networks.training.propagation.GradientWorkerOwner
    public void report(double[] dArr, double d2, Throwable th) {
        synchronized (this) {
            if (th != null) {
                this.reportedException = th;
                return;
            }
            for (int g = 0; g < dArr.length; g++) {
                this.gradients[g] += dArr[g];
            }
            this.totalError += d2;
        }
    }

    /**
     * Increments this class's private iteration counter.
     *
     * NOTE(review): the counter is not read in this class; confirm whether BasicTraining
     * tracks iterations separately.
     */
    public void rollIteration() {
        this.iteration++;
    }

    @Override // org.encog.neural.networks.training.BatchSize
    public void setBatchSize(int i) {
        this.batchSize = i;
    }

    /**
     * Sets the dropout rate.
     *
     * @param d2 the dropout rate; values above 0 enable the dropout update path
     */
    public void setDropoutRate(double d2) {
        this.dropoutRate = d2;
    }

    /**
     * Misspelled alias of setDropoutRate, kept for existing callers.
     *
     * @param d2 the dropout rate
     * @deprecated use {@link #setDropoutRate(double)}
     */
    @Deprecated
    public void setDroupoutRate(double d2) {
        setDropoutRate(d2);
    }

    /**
     * Sets the error function. Only affects workers created after this call.
     *
     * @param errorFunction the error function to hand to new gradient workers
     */
    public void setErrorFunction(ErrorFunction errorFunction) {
        this.ef = errorFunction;
    }

    public void setL1(double d2) {
        this.l1 = d2;
    }

    public void setL2(double d2) {
        this.l2 = d2;
    }

    @Override // org.encog.util.concurrency.MultiThreadable
    public void setThreadCount(int i) {
        this.numThreads = i;
    }

    /**
     * Computes the change to add to weight {@code i}.
     *
     * @param dArr the accumulated gradients
     * @param dArr2 the previous gradients
     * @param i the weight index
     * @return the delta to add to the weight
     */
    public abstract double updateWeight(double[] dArr, double[] dArr2, int i);

    /**
     * Computes the change to add to weight {@code i} when dropout is active.
     *
     * @param dArr the accumulated gradients
     * @param dArr2 the previous gradients
     * @param i the weight index
     * @param d2 the dropout rate
     * @return the delta to add to the weight
     */
    public abstract double updateWeight(double[] dArr, double[] dArr2, int i, double d2);
}
