/*
 * Decompiled with CFR 0.152.
 */
package layer;

import java.util.ArrayList;
import java.util.Arrays;
import layer.PlaneLayer;

/**
 * Non-overlapping 2-D max-pooling layer.
 *
 * <p>The layer owns {@code frameNum} independent {@link PoolingFrame}s, one per input plane. Each
 * frame moves a {@code poolingRow x poolingCol} window over its input with a stride equal to the
 * window size. It emits the window maximum and remembers which cell produced it, so that
 * {@link #backward()} can route the incoming gradient back to exactly that input cell.
 *
 * <p>Padding is virtual. {@code padding} cells are conceptually added on every side of the input,
 * but they never take part in the maximum (they behave like {@code -Infinity}), and the input
 * arrays are never resized. A window that lies entirely inside the padding outputs {@code 0.0} and
 * receives no gradient.
 */
public class PoolingLayer extends PlaneLayer {
    /** Arg-max marker meaning "this window contained no real input cell". */
    private static final int NO_MAX = -1;

    private int padding;
    /** Total number of outputs across all frames: frameNum * outNodeRow * outNodeCol. */
    private int allNodeNum;
    private int frameNum;
    private final ArrayList<PoolingFrame> frameList = new ArrayList<>();

    public PoolingLayer(int inputNodeRow, int inputNodeCol) {
        super(inputNodeRow, inputNodeCol);
    }

    /**
     * Sizes the output and allocates one pooling frame per input plane.
     *
     * <p>The output size in each dimension is {@code (in + 2 * padding) / pool}, using integer
     * division, so trailing rows/columns that do not fill a whole window are dropped. Each frame's
     * x, y, dE/dX and dE/dY arrays are appended by reference to the inherited {@code xList},
     * {@code yList}, {@code dEdXList} and {@code dEdYList}.
     *
     * @param frameNum   number of independent planes to pool
     * @param padding    virtual border width added on each side of the input
     * @param poolingRow window height, also used as the vertical stride
     * @param poolingCol window width, also used as the horizontal stride
     * @throws IllegalArgumentException if {@code poolingRow} or {@code poolingCol} is not positive
     */
    public void initialize(int frameNum, int padding, int poolingRow, int poolingCol) {
        // Validate before touching any state so a bad call leaves the layer unchanged.
        if (poolingRow <= 0 || poolingCol <= 0) {
            throw new IllegalArgumentException(
                    "pooling window must be positive, got " + poolingRow + "x" + poolingCol);
        }
        this.frameNum = frameNum;
        this.padding = padding;
        this.outNodeRow = (this.inNodeRow + padding * 2) / poolingRow;
        this.outNodeCol = (this.inNodeCol + padding * 2) / poolingCol;
        this.allNodeNum = frameNum * this.outNodeRow * this.outNodeCol;
        for (int i = 0; i < frameNum; ++i) {
            // Frames read outNodeRow/outNodeCol in their constructor, so these must be set first.
            PoolingFrame pf = new PoolingFrame(this.inNodeRow, this.inNodeCol, poolingRow, poolingCol);
            this.frameList.add(pf);
            this.xList.add(pf.getX());
            this.yList.add(pf.getY());
            this.dEdXList.add(pf.get_dEdX());
            this.dEdYList.add(pf.get_dEdY());
        }
    }

    /**
     * Replaces the inherited input list and rebinds frame {@code k}'s input to
     * {@code xList.get(k)}.
     *
     * @param xList new per-frame inputs; each array should be at least
     *              {@code inNodeRow x inNodeCol}
     * @throws IndexOutOfBoundsException if {@code xList} has more entries than there are frames
     */
    @Override
    public void set_xList(ArrayList<double[][]> xList) {
        this.xList = xList;
        for (int k = 0; k < xList.size(); ++k) {
            this.frameList.get(k).setX(xList.get(k));
        }
    }

    /** Computes every frame's pooled output {@code y} and its arg-max tables from {@code x}. */
    @Override
    public void forward() {
        for (int i = 0; i < this.frameNum; ++i) {
            this.frameList.get(i).calcuOutput();
        }
    }

    /**
     * Routes every frame's {@code dEdY} back to the cells chosen by the last {@link #forward()}.
     * All other cells of {@code dEdX} are set to zero.
     */
    @Override
    public void backward() {
        for (int i = 0; i < this.frameNum; ++i) {
            this.frameList.get(i).calcuGradient();
        }
    }

    /** Pooling has no trainable parameters, so there is nothing to update. */
    @Override
    public void update() {
    }

    /** Max-pooling state for a single input plane. */
    class PoolingFrame {
        private final int inputRow;
        private final int inputCol;
        private final int poolingRow;
        private final int poolingCol;
        private double[][] x;
        private double[][] y;
        private double[][] dEdX;
        private double[][] dEdY;
        /** Row offset, inside its window, of each output's maximum; NO_MAX if none. */
        private final int[][] maxRow;
        /** Column offset, inside its window, of each output's maximum; NO_MAX if none. */
        private final int[][] maxCol;

        public PoolingFrame(int inputNodeRow, int inputNodeCol, int poolingRow, int poolingCol) {
            this.inputRow = inputNodeRow;
            this.inputCol = inputNodeCol;
            this.poolingRow = poolingRow;
            this.poolingCol = poolingCol;
            this.x = new double[inputNodeRow][inputNodeCol];
            this.y = new double[PoolingLayer.this.outNodeRow][PoolingLayer.this.outNodeCol];
            this.dEdX = new double[inputNodeRow][inputNodeCol];
            this.dEdY = new double[PoolingLayer.this.outNodeRow][PoolingLayer.this.outNodeCol];
            this.maxRow = new int[PoolingLayer.this.outNodeRow][PoolingLayer.this.outNodeCol];
            this.maxCol = new int[PoolingLayer.this.outNodeRow][PoolingLayer.this.outNodeCol];
            // Until forward() runs there is no arg-max. A backward() call before that must not
            // route gradient to arbitrary (or, with padding, out-of-range) cells.
            for (int[] row : this.maxRow) {
                Arrays.fill(row, NO_MAX);
            }
            for (int[] row : this.maxCol) {
                Arrays.fill(row, NO_MAX);
            }
        }

        public double[][] getX() {
            return this.x;
        }

        public void setX(double[][] x) {
            this.x = x;
        }

        public double[][] getY() {
            return this.y;
        }

        public void setY(double[][] y) {
            this.y = y;
        }

        public double[][] get_dEdX() {
            return this.dEdX;
        }

        public void set_dEdX(double[][] dEdX) {
            this.dEdX = dEdX;
        }

        public double[][] get_dEdY() {
            return this.dEdY;
        }

        public void set_dEdY(double[][] dEdY) {
            this.dEdY = dEdY;
        }

        /**
         * Fills {@code y[i][j]} with the maximum of window (i, j) and records where it was.
         * Padded cells are skipped. Among equal values the first in row-major order wins.
         */
        private void calcuOutput() {
            final int pad = PoolingLayer.this.padding;
            for (int i = 0; i < PoolingLayer.this.outNodeRow; ++i) {
                for (int j = 0; j < PoolingLayer.this.outNodeCol; ++j) {
                    double max = 0.0;
                    int maxI = NO_MAX;
                    int maxJ = NO_MAX;
                    for (int m = 0; m < this.poolingRow; ++m) {
                        // Window coordinates live in the padded space; shift back to input space.
                        int p = i * this.poolingRow + m - pad;
                        if (p < 0 || p >= this.inputRow) {
                            continue;
                        }
                        for (int n = 0; n < this.poolingCol; ++n) {
                            int q = j * this.poolingCol + n - pad;
                            if (q < 0 || q >= this.inputCol) {
                                continue;
                            }
                            double val = this.x[p][q];
                            // Seed with the first real cell rather than 0.0, so all-negative
                            // windows report their true maximum and its true position.
                            if (maxI == NO_MAX || val > max) {
                                max = val;
                                maxI = m;
                                maxJ = n;
                            }
                        }
                    }
                    this.y[i][j] = maxI == NO_MAX ? 0.0 : max;
                    this.maxRow[i][j] = maxI;
                    this.maxCol[i][j] = maxJ;
                }
            }
        }

        /**
         * Writes {@code dEdY[i][j]} into the input cell that was the maximum of window (i, j)
         * and zero into every other cell of {@code dEdX}.
         */
        private void calcuGradient() {
            // Zero in place rather than reallocating: the array object handed out by get_dEdX()
            // is also held in dEdXList by initialize(). Cells that were a maximum on an earlier
            // pass must not keep their stale gradient.
            for (double[] row : this.dEdX) {
                Arrays.fill(row, 0.0);
            }
            final int pad = PoolingLayer.this.padding;
            for (int i = 0; i < PoolingLayer.this.outNodeRow; ++i) {
                for (int j = 0; j < PoolingLayer.this.outNodeCol; ++j) {
                    int winRow = this.maxRow[i][j];
                    if (winRow == NO_MAX) {
                        continue; // window held only padding: nothing to route back
                    }
                    int p = this.poolingRow * i + winRow - pad;
                    int q = this.poolingCol * j + this.maxCol[i][j] - pad;
                    // Windows never overlap (stride == window size), so assignment suffices.
                    this.dEdX[p][q] = this.dEdY[i][j];
                }
            }
        }

        /** A pooling frame has no parameters; kept as a no-op for symmetry with other layers. */
        public void update() {
        }
    }
}

