/*
 * Decompiled with CFR 0.152.
 */
package layer;

import java.util.ArrayList;
import java.util.Random;
import layer.Layer;
import layer.PlaneLayer;
import util.DJ;

/**
 * Convolution layer made of {@code filterNum} independent single-channel filters.
 *
 * <p>Every filter slides a {@code filterHight x filterWidth} kernel over the same
 * input plane with a fixed stride and no padding, adds a bias and applies a
 * rectifier (outputs less than or equal to zero become zero). Each filter owns its
 * output plane, its gradient buffers and its optimizer state; the per-filter arrays
 * are published through the list fields inherited from {@link PlaneLayer}.
 */
public class ConvoluteLayer
extends PlaneLayer {
    /** Learning rate used by the per-filter optimizers. */
    private double eta;
    /** Scale applied to the Gaussian samples used for initial weights and biases. */
    private double initialCoef;
    /** Random source obtained from {@code DJ.getRandom()} in {@link #initialize}. */
    private Random randum;
    /** Number of filters processed by forward/backward/update. */
    private int filterNum;
    private final ArrayList<Filter> filterList = new ArrayList<>();
    /** Kernel arrays of the filters, in filter order (same array instances the filters use). */
    private ArrayList<double[][]> filterWeightList = new ArrayList<>();

    /**
     * Creates a layer for an input plane of the given size.
     *
     * @param inputNodeRow number of input rows
     * @param inputNodeCol number of input columns
     */
    public ConvoluteLayer(int inputNodeRow, int inputNodeCol) {
        super(inputNodeRow, inputNodeCol);
    }

    /**
     * Configures the layer, derives the output plane size and creates the filters.
     *
     * <p>All filters are given the same freshly allocated input array, so writing an
     * input sample into that array feeds every filter.
     *
     * @param eta                 learning rate
     * @param initialCoef         multiplier for the Gaussian initial weights and biases
     * @param filterNum           number of filters to create
     * @param filterHight         kernel height, must be in {@code [1, inNodeRow]}
     * @param filterWidth         kernel width, must be in {@code [1, inNodeCol]}
     * @param filterStride        step between neighbouring windows, must be positive
     * @param gradientDescentType {@code 1} selects AdaGrad, any other value plain SGD
     * @throws IllegalArgumentException if the stride is not positive or the kernel
     *                                  does not fit inside the input plane
     */
    public void initialize(double eta, double initialCoef, int filterNum, int filterHight, int filterWidth, int filterStride, int gradientDescentType) {
        // Fail fast: a zero stride would divide by zero below, and a kernel larger
        // than the input yields negative or inconsistent output sizes.
        if (filterStride <= 0) {
            throw new IllegalArgumentException("filterStride must be positive: " + filterStride);
        }
        if (filterHight < 1 || filterHight > this.inNodeRow) {
            throw new IllegalArgumentException("filterHight must be in [1, " + this.inNodeRow + "]: " + filterHight);
        }
        if (filterWidth < 1 || filterWidth > this.inNodeCol) {
            throw new IllegalArgumentException("filterWidth must be in [1, " + this.inNodeCol + "]: " + filterWidth);
        }
        this.eta = eta;
        this.initialCoef = initialCoef;
        this.filterNum = filterNum;
        // Number of window positions that fit without padding.
        this.outNodeRow = (this.inNodeRow - filterHight) / filterStride + 1;
        this.outNodeCol = (this.inNodeCol - filterWidth) / filterStride + 1;
        this.randum = DJ.getRandom();
        // One input buffer shared by every filter.
        double[][] x = new double[this.inNodeRow][this.inNodeCol];
        for (int i = 0; i < filterNum; ++i) {
            Filter filter = new Filter(filterHight, filterWidth, filterStride);
            filter.initialize(gradientDescentType);
            this.filterList.add(filter);
            filter.setX(x);
            this.xList.add(filter.getX());
            this.yList.add(filter.getY());
            this.dEdXList.add(filter.get_dEdX());
            this.dEdYList.add(filter.get_dEdY());
            this.filterWeightList.add(filter.getFilterWeight());
        }
    }

    /**
     * Replaces the output-gradient list and hands the k-th array to the k-th filter.
     *
     * @param dEdYList one output-gradient plane per filter
     */
    @Override
    public void set_dEdYList(ArrayList<double[][]> dEdYList) {
        this.dEdYList = dEdYList;
        for (int k = 0; k < dEdYList.size(); ++k) {
            this.filterList.get(k).set_dEdY(dEdYList.get(k));
        }
    }

    /**
     * @return the list of kernel arrays, in filter order
     */
    public ArrayList<double[][]> getFilterWeightList() {
        return this.filterWeightList;
    }

    /**
     * Replaces the kernel list and installs the k-th array as the kernel of the k-th
     * filter, so that the filters actually compute with the supplied weights.
     * Entries beyond the number of existing filters are stored in the list but not
     * installed anywhere.
     *
     * @param filterWeightList one kernel per filter; may be {@code null}, in which
     *                         case only the list reference is replaced
     */
    public void setFilterWeightList(ArrayList<double[][]> filterWeightList) {
        this.filterWeightList = filterWeightList;
        if (filterWeightList == null) {
            return;
        }
        int count = Math.min(filterWeightList.size(), this.filterList.size());
        for (int k = 0; k < count; ++k) {
            this.filterList.get(k).setFilterWeight(filterWeightList.get(k));
        }
    }

    /** Computes the output plane of every filter from the current input. */
    @Override
    public void forward() {
        for (int i = 0; i < this.filterNum; ++i) {
            this.filterList.get(i).calcuOutput();
        }
    }

    /** Computes input gradients and accumulates parameter gradients for every filter. */
    @Override
    public void backward() {
        for (int i = 0; i < this.filterNum; ++i) {
            this.filterList.get(i).calcuGradient();
        }
    }

    /** Applies the accumulated parameter gradients of every filter and clears them. */
    @Override
    public void update() {
        for (int i = 0; i < this.filterNum; ++i) {
            this.filterList.get(i).gradientDescent.descent();
        }
    }

    /**
     * A single kernel with its bias, activation buffers, gradient accumulators and
     * optimizer. Output-sized arrays use the enclosing layer's output dimensions at
     * construction time.
     */
    class Filter {
        private final int filterHight;
        private final int filterWidth;
        private final int filterStride;
        /** Input plane (shared with the other filters of the layer). */
        private double[][] x;
        /** Output plane after the rectifier. */
        private double[][] y;
        private double[][] filterWeight;
        private double filterBias;
        /** Pre-activation values, kept for the rectifier derivative in the backward pass. */
        private final double[][] u;
        /** Gradient with respect to the input plane. */
        private double[][] dEdX;
        /** Gradient with respect to the output plane. */
        private double[][] dEdY;
        /** Kernel gradient accumulated since the last descent step. */
        private final double[][] dEdW_sum;
        /** Bias gradient accumulated since the last descent step. */
        private double dEdB_sum;
        /** AdaGrad accumulator of squared kernel gradients. */
        private final double[][] hWeight;
        /** AdaGrad accumulator of squared bias gradients. */
        private double hBias;
        private Layer.GradientDescent gradientDescent;

        /**
         * Allocates all buffers; weights and optimizer are set up by {@link #initialize}.
         *
         * @param filterHight  kernel height
         * @param filterWidth  kernel width
         * @param filterStride step between neighbouring windows
         */
        public Filter(int filterHight, int filterWidth, int filterStride) {
            this.filterHight = filterHight;
            this.filterWidth = filterWidth;
            this.filterStride = filterStride;
            this.filterWeight = new double[filterHight][filterWidth];
            this.filterBias = 0.0;
            this.u = new double[ConvoluteLayer.this.outNodeRow][ConvoluteLayer.this.outNodeCol];
            this.y = new double[ConvoluteLayer.this.outNodeRow][ConvoluteLayer.this.outNodeCol];
            this.dEdX = new double[ConvoluteLayer.this.inNodeRow][ConvoluteLayer.this.inNodeCol];
            this.dEdY = new double[ConvoluteLayer.this.outNodeRow][ConvoluteLayer.this.outNodeCol];
            this.dEdW_sum = new double[filterHight][filterWidth];
            this.dEdB_sum = 0.0;
            this.hWeight = new double[filterHight][filterWidth];
            this.hBias = 0.0;
        }

        /**
         * Draws the initial kernel and bias from a scaled Gaussian, seeds the AdaGrad
         * accumulators with a small positive value and selects the optimizer.
         *
         * @param gradientDescentType {@code 1} for AdaGrad, anything else for SGD
         */
        public void initialize(int gradientDescentType) {
            for (int i = 0; i < this.filterHight; ++i) {
                for (int j = 0; j < this.filterWidth; ++j) {
                    this.filterWeight[i][j] = ConvoluteLayer.this.initialCoef * ConvoluteLayer.this.randum.nextGaussian();
                    // Small positive start so the first AdaGrad step never divides by zero.
                    this.hWeight[i][j] = 1.0E-8;
                }
            }
            // The bias is a single scalar: draw it once, not once per kernel row.
            this.filterBias = ConvoluteLayer.this.initialCoef * ConvoluteLayer.this.randum.nextGaussian();
            this.hBias = 1.0E-8;
            switch (gradientDescentType) {
                case 1: {
                    this.gradientDescent = new AdaGrad();
                    break;
                }
                default: {
                    this.gradientDescent = new SGD();
                }
            }
        }

        public double[][] getX() {
            return this.x;
        }

        public void setX(double[][] x) {
            this.x = x;
        }

        public double[][] getY() {
            return this.y;
        }

        public void setY(double[][] y) {
            this.y = y;
        }

        public double[][] getFilterWeight() {
            return this.filterWeight;
        }

        public void setFilterWeight(double[][] filterWeight) {
            this.filterWeight = filterWeight;
        }

        public double[][] get_dEdX() {
            return this.dEdX;
        }

        public void set_dEdX(double[][] dEdX) {
            this.dEdX = dEdX;
        }

        public double[][] get_dEdY() {
            return this.dEdY;
        }

        public void set_dEdY(double[][] dEdY) {
            this.dEdY = dEdY;
        }

        /**
         * Forward pass: for every output cell, sums bias plus the element-wise
         * product of the kernel with the input window, stores it in {@code u} and
         * writes the rectified value to {@code y}.
         */
        private void calcuOutput() {
            for (int i = 0; i < ConvoluteLayer.this.outNodeRow; ++i) {
                for (int j = 0; j < ConvoluteLayer.this.outNodeCol; ++j) {
                    double sum = this.filterBias;
                    for (int m = 0; m < this.filterHight; ++m) {
                        int p = i * this.filterStride + m;
                        for (int n = 0; n < this.filterWidth; ++n) {
                            int q = j * this.filterStride + n;
                            sum += this.x[p][q] * this.filterWeight[m][n];
                        }
                    }
                    this.u[i][j] = sum;
                    this.y[i][j] = sum <= 0.0 ? 0.0 : sum;
                }
            }
        }

        /**
         * Backward pass: recomputes {@code dEdX} for the current sample and adds this
         * sample's contribution to the kernel and bias gradient accumulators.
         *
         * <p>Output cells whose pre-activation is not positive contribute nothing
         * (rectifier derivative is zero there).
         */
        private void calcuGradient() {
            // Start from zero: windows may overlap, so contributions are summed below,
            // and cells untouched in this pass must not keep values from an earlier one.
            for (double[] row : this.dEdX) {
                for (int c = 0; c < row.length; ++c) {
                    row[c] = 0.0;
                }
            }
            for (int i = 0; i < ConvoluteLayer.this.outNodeRow; ++i) {
                for (int j = 0; j < ConvoluteLayer.this.outNodeCol; ++j) {
                    if (!(this.u[i][j] > 0.0)) continue;
                    double delta = this.dEdY[i][j];
                    for (int m = 0; m < this.filterHight; ++m) {
                        int p = i * this.filterStride + m;
                        for (int n = 0; n < this.filterWidth; ++n) {
                            int q = j * this.filterStride + n;
                            this.dEdW_sum[m][n] += delta * this.x[p][q];
                            // Accumulate: an input cell can lie inside several windows.
                            this.dEdX[p][q] += delta * this.filterWeight[m][n];
                        }
                    }
                    this.dEdB_sum += delta;
                }
            }
        }

        /**
         * AdaGrad step: each parameter's learning rate is divided by the square root
         * of its running sum of squared accumulated gradients. Accumulated gradients
         * are reset to zero after the step.
         */
        class AdaGrad
        extends Layer.GradientDescent {
            AdaGrad() {
                super(ConvoluteLayer.this);
            }

            @Override
            void descent() {
                Filter.this.hBias = Filter.this.hBias + Filter.this.dEdB_sum * Filter.this.dEdB_sum;
                Filter.this.filterBias = Filter.this.filterBias - ConvoluteLayer.this.eta / Math.sqrt(Filter.this.hBias) * Filter.this.dEdB_sum;
                Filter.this.dEdB_sum = 0.0;
                for (int i = 0; i < Filter.this.filterHight; ++i) {
                    for (int j = 0; j < Filter.this.filterWidth; ++j) {
                        double grad = Filter.this.dEdW_sum[i][j];
                        Filter.this.hWeight[i][j] = Filter.this.hWeight[i][j] + grad * grad;
                        Filter.this.filterWeight[i][j] = Filter.this.filterWeight[i][j] - ConvoluteLayer.this.eta / Math.sqrt(Filter.this.hWeight[i][j]) * grad;
                        Filter.this.dEdW_sum[i][j] = 0.0;
                    }
                }
            }
        }

        /**
         * Plain gradient descent step: subtracts {@code eta} times the accumulated
         * gradient from each parameter, then resets the accumulators to zero.
         */
        class SGD
        extends Layer.GradientDescent {
            SGD() {
                super(ConvoluteLayer.this);
            }

            @Override
            void descent() {
                Filter.this.filterBias = Filter.this.filterBias - ConvoluteLayer.this.eta * Filter.this.dEdB_sum;
                Filter.this.dEdB_sum = 0.0;
                for (int i = 0; i < Filter.this.filterHight; ++i) {
                    for (int j = 0; j < Filter.this.filterWidth; ++j) {
                        Filter.this.filterWeight[i][j] = Filter.this.filterWeight[i][j] - ConvoluteLayer.this.eta * Filter.this.dEdW_sum[i][j];
                        Filter.this.dEdW_sum[i][j] = 0.0;
                    }
                }
            }
        }
    }
}

