/*
 * Decompiled with CFR 0.152.
 */
package layer;

import active.Activator;
import java.util.Random;
import util.DJ;

/**
 * Base class for a fully connected layer: {@code u = b + w * x}, followed by
 * an {@link Activator} that maps {@code u} to {@code y}.
 *
 * <p>All working buffers are public fields sized in the constructor. The
 * gradient routines accumulate into {@code dEdW_sum} / {@code dEdB_sum};
 * {@link #update()} applies the accumulated gradients through the configured
 * {@link GradientDescent} strategy and resets the accumulators to zero.
 *
 * <p>{@link #initialize(double, double, String, int)} must be called before
 * {@link #forward()}, {@link #backward(double[])} or {@link #update()},
 * because it is the only place where {@code activator} and
 * {@code gradientDescent} are assigned.
 */
public abstract class Layer {
    /** Gradient descent type selecting the plain update rule (the default). */
    public static final int SGD = 0;
    /** Gradient descent type selecting the AdaGrad update rule. */
    public static final int ADA_GRAD = 1;

    /** Number of input nodes (length of {@code x}). */
    public int inNodeNum;
    /** Number of output nodes (length of {@code y}). */
    public int outNodeNum;
    /** Input values. */
    public double[] x;
    /** Weights, indexed {@code [output][input]}. */
    public double[][] w;
    /** Biases, one per output node. */
    public double[] b;
    /** Pre-activation sums {@code b[i] + sum_j w[i][j] * x[j]}. */
    public double[] u;
    /** Output buffer handed to the activator together with {@code u}. */
    public double[] y;
    /** Values subtracted from {@code y} to form {@code e} (presumably the target signal). */
    public double[] c;
    /** Difference {@code y - c}. */
    public double[] e;
    /** Gradient with respect to the inputs, {@code sum_i w[i][j] * dEdU[i]}. */
    public double[] dEdX;
    /** Most recent per-sample weight gradient, {@code dEdU[i] * x[j]}. */
    public double[][] dEdW;
    /** Most recent per-sample bias gradient (equal to {@code dEdU}). */
    public double[] dEdB;
    /** Gradient with respect to the pre-activation sums. */
    public double[] dEdU;
    /** Buffer filled by {@code activator.derivative(u, y, dYdU)}. */
    public double[] dYdU;
    /** Gradient with respect to the outputs; aliased to {@code e} by {@link #calcuError()}. */
    public double[] dEdY;
    /** Weight gradients accumulated since the last {@link #update()}. */
    public double[][] dEdW_sum;
    /** Bias gradients accumulated since the last {@link #update()}. */
    public double[] dEdB_sum;
    /** AdaGrad accumulator of squared weight gradients. */
    public double[][] hWeight;
    /** AdaGrad accumulator of squared bias gradients. */
    public double[] hBias;
    /** Activation function, created by name in {@code initialize}. */
    public Activator activator;
    /** Name that was passed to {@code Activator.createActivator}. */
    public String activatorName;
    /** Update rule used by {@link #update()}. */
    public GradientDescent gradientDescent;
    /** Learning rate. */
    public double eta;

    /**
     * Allocates every buffer for a layer of the given dimensions. Weights and
     * biases stay at zero until
     * {@link #initialize(double, double, String, int)} is called.
     *
     * @param inputNodeNum  number of input nodes
     * @param outputNodeNum number of output nodes
     */
    public Layer(int inputNodeNum, int outputNodeNum) {
        inNodeNum = inputNodeNum;
        outNodeNum = outputNodeNum;

        x = new double[inputNodeNum];
        w = new double[outputNodeNum][inputNodeNum];
        b = new double[outputNodeNum];
        u = new double[outputNodeNum];
        y = new double[outputNodeNum];
        c = new double[outputNodeNum];
        e = new double[outputNodeNum];

        dEdX = new double[inputNodeNum];
        dEdW = new double[outputNodeNum][inputNodeNum];
        dEdB = new double[outputNodeNum];
        dEdU = new double[outputNodeNum];
        dYdU = new double[outputNodeNum];
        dEdY = new double[outputNodeNum];

        dEdW_sum = new double[outputNodeNum][inputNodeNum];
        dEdB_sum = new double[outputNodeNum];
        hWeight = new double[outputNodeNum][inputNodeNum];
        hBias = new double[outputNodeNum];
    }

    /** No-op hook; this base implementation does nothing. */
    public void initialize() {
    }

    /**
     * Randomises the parameters and selects the activator and update rule.
     *
     * <p>Each weight and bias is set to {@code initialCoef} times a Gaussian
     * draw from {@code DJ.getRandom()}; for every output node the weights are
     * drawn first (in input order), then the bias. The AdaGrad accumulators
     * are seeded with {@code 1.0E-8}.
     *
     * @param eta                 learning rate
     * @param initialCoef         scale applied to each Gaussian draw
     * @param activatorName       name handed to {@code Activator.createActivator}
     * @param gradientDescentType {@link #ADA_GRAD} for AdaGrad; any other
     *                            value selects the plain SGD rule
     */
    public void initialize(double eta, double initialCoef, String activatorName, int gradientDescentType) {
        this.activatorName = activatorName;
        this.eta = eta;

        Random random = DJ.getRandom();
        for (int row = 0; row < outNodeNum; row++) {
            for (int col = 0; col < inNodeNum; col++) {
                w[row][col] = initialCoef * random.nextGaussian();
                hWeight[row][col] = 1.0E-8;
            }
            b[row] = initialCoef * random.nextGaussian();
            hBias[row] = 1.0E-8;
        }

        activator = Activator.createActivator(activatorName);

        if (gradientDescentType == ADA_GRAD) {
            gradientDescent = new AdaGrad();
        } else {
            gradientDescent = new SGD();
        }
    }

    /** @return the input buffer itself (not a copy) */
    public double[] getX() {
        return x;
    }

    /** @param x array to use as the input buffer (stored by reference) */
    public void setX(double[] x) {
        this.x = x;
    }

    /** @return the weight matrix itself (not a copy) */
    public double[][] getW() {
        return w;
    }

    /** @param w array to use as the weight matrix (stored by reference) */
    public void setW(double[][] w) {
        this.w = w;
    }

    /** @return the bias array itself (not a copy) */
    public double[] getB() {
        return b;
    }

    /** @param b array to use as the biases (stored by reference) */
    public void setB(double[] b) {
        this.b = b;
    }

    /** @return the pre-activation buffer itself (not a copy) */
    public double[] getU() {
        return u;
    }

    /** @param u array to use as the pre-activation buffer (stored by reference) */
    public void setU(double[] u) {
        this.u = u;
    }

    /** @return the output buffer itself (not a copy) */
    public double[] getY() {
        return y;
    }

    /** @param y array to use as the output buffer (stored by reference) */
    public void setY(double[] y) {
        this.y = y;
    }

    /** @return the {@code c} buffer itself (not a copy) */
    public double[] getC() {
        return c;
    }

    /** @param c array to use as the {@code c} buffer (stored by reference) */
    public void setC(double[] c) {
        this.c = c;
    }

    /** @return the {@code e} buffer itself (not a copy) */
    public double[] getE() {
        return e;
    }

    /** @param e array to use as the {@code e} buffer (stored by reference) */
    public void setE(double[] e) {
        this.e = e;
    }

    /** @return the input-gradient buffer itself (not a copy) */
    public double[] getdEdX() {
        return dEdX;
    }

    /** @param dEdX array to use as the input-gradient buffer (stored by reference) */
    public void setdEdX(double[] dEdX) {
        this.dEdX = dEdX;
    }

    /** @return the output-gradient buffer itself (not a copy) */
    public double[] getdEdY() {
        return dEdY;
    }

    /** @param dEdY array to use as the output-gradient buffer (stored by reference) */
    public void setdEdY(double[] dEdY) {
        this.dEdY = dEdY;
    }

    /** Runs the forward pass: fills {@code u} from {@code x}, then invokes the activator. */
    public void forward() {
        calcuOutput();
    }

    /** Computes {@code u = b + w * x} and calls {@code activator.function(u, y)}. */
    private void calcuOutput() {
        for (int row = 0; row < outNodeNum; row++) {
            u[row] = b[row];
            for (int col = 0; col < inNodeNum; col++) {
                u[row] += w[row][col] * x[col];
            }
        }
        activator.function(u, y);
    }

    /**
     * Runs the backward pass against the given values, accumulating gradients.
     *
     * @param c values subtracted from {@code y} to obtain the error
     */
    public void backward(double[] c) {
        calcuGradient(c);
    }

    /**
     * Runs the backward pass without explicit targets.
     *
     * <p>NOTE: the underlying routine could not be recovered by the
     * decompiler, so this call currently always throws
     * {@link IllegalStateException}.
     */
    public void backward() {
        calcuGradient();
    }

    /**
     * Stores {@code y - c} in {@code e} and then makes {@code dEdY} refer to
     * the same array as {@code e}.
     */
    public void calcuError() {
        for (int row = 0; row < outNodeNum; row++) {
            e[row] = y[row] - c[row];
        }
        dEdY = e;
    }

    /**
     * Computes the gradients for one sample and adds them to the accumulators.
     *
     * <p>After asking the activator for {@code dYdU}, {@code dEdX} is cleared
     * and, for each output node, the error {@code y - c}, {@code dEdU},
     * {@code dEdB} and {@code dEdW} are computed. {@code dEdB_sum} and
     * {@code dEdW_sum} are incremented, and {@code dEdX} collects
     * {@code w[i][j] * dEdU[i]} over all output nodes.
     *
     * @param c values subtracted from {@code y} (this parameter, not the field {@code c})
     */
    private void calcuGradient(double[] c) {
        activator.derivative(u, y, dYdU);

        for (int col = 0; col < inNodeNum; col++) {
            dEdX[col] = 0.0;
        }

        for (int row = 0; row < outNodeNum; row++) {
            e[row] = y[row] - c[row];
            dEdU[row] = e[row] * dYdU[row];
            dEdB[row] = dEdU[row];
            dEdB_sum[row] += dEdB[row];
            for (int col = 0; col < inNodeNum; col++) {
                dEdW[row][col] = dEdU[row] * x[col];
                dEdW_sum[row][col] += dEdW[row][col];
                dEdX[col] += w[row][col] * dEdU[row];
            }
        }
    }

    /**
     * Placeholder for the target-less gradient routine.
     *
     * <p>CFR 0.152 failed to decompile the original body: a
     * {@code ConfusedCFRException} ("Can't sort instructions") was raised
     * from {@code SwitchReplacer.buildSwitchCases} while it was ordering a
     * {@code default} case, so the original method contained a switch
     * statement whose logic is not available here. Recover the body from the
     * class file or original sources before relying on {@link #backward()}.
     *
     * @throws IllegalStateException always
     */
    private void calcuGradient() {
        throw new IllegalStateException("Decompilation failed");
    }

    /** Applies the accumulated gradients using the configured update rule. */
    public void update() {
        gradientDescent.descent();
    }

    /**
     * Plain gradient step: subtracts {@code eta} times each accumulated
     * gradient from the matching bias or weight, then zeroes the accumulator.
     */
    private void updateGradient() {
        for (int row = 0; row < outNodeNum; row++) {
            b[row] -= eta * dEdB_sum[row];
            dEdB_sum[row] = 0.0;
            for (int col = 0; col < inNodeNum; col++) {
                w[row][col] -= eta * dEdW_sum[row][col];
                dEdW_sum[row][col] = 0.0;
            }
        }
    }

    /** Strategy applied by {@link Layer#update()} to consume the accumulated gradients. */
    abstract class GradientDescent {
        GradientDescent() {
        }

        /** Updates the enclosing layer's weights and biases and resets its gradient accumulators. */
        abstract void descent();
    }

    /** Fixed-learning-rate update rule. */
    class SGD
    extends GradientDescent {
        SGD() {
        }

        /** Performs the plain gradient step on the enclosing layer. */
        @Override
        void descent() {
            Layer.this.updateGradient();
        }
    }

    /**
     * AdaGrad update rule: each parameter's step is {@code eta} divided by
     * the square root of its running sum of squared accumulated gradients.
     */
    class AdaGrad
    extends GradientDescent {
        AdaGrad() {
        }

        /**
         * For every bias and weight: adds the squared accumulated gradient to
         * the matching {@code h} accumulator, applies the scaled step, and
         * zeroes the gradient accumulator.
         */
        @Override
        void descent() {
            for (int row = 0; row < outNodeNum; row++) {
                hBias[row] += dEdB_sum[row] * dEdB_sum[row];
                b[row] -= eta / Math.sqrt(hBias[row]) * dEdB_sum[row];
                dEdB_sum[row] = 0.0;
                for (int col = 0; col < inNodeNum; col++) {
                    hWeight[row][col] += dEdW_sum[row][col] * dEdW_sum[row][col];
                    w[row][col] -= eta / Math.sqrt(hWeight[row][col]) * dEdW_sum[row][col];
                    dEdW_sum[row][col] = 0.0;
                }
            }
        }
    }
}

