/*
 *  Title: DaiJa_V4 ( Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package layer;

import java.util.ArrayList;

/**
 * <p> 表　題: Class: PoolingLayer</p>
 * <p> 説　明: プーリング層クラス</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.3.28</p>
 */
public class PoolingLayer extends PlaneLayer {
  
  // 変数宣言
  private int padding; // パデング量
  private int allNodeNum; // プーリング層の総出力ノード数  
  
  private int frameNum;// プーリング枠の枚数（フィルタの枚数に等しい）
  private final ArrayList<PoolingFrame> frameList; // プーリング枠のリスト
  
  
  public PoolingLayer(int inputNodeRow, int inputNodeCol) {
    super(inputNodeRow, inputNodeCol);

    frameList = new ArrayList<>();
  }

  // 初期化
  public void initialize(int frameNum, int padding,
          int poolingRow, int poolingCol) {
    this.frameNum = frameNum;
    this.padding = padding;
    
    outNodeRow = (inNodeRow + padding * 2) / poolingRow;
    outNodeCol = (inNodeCol + padding * 2) / poolingCol;
    allNodeNum = frameNum * outNodeRow * outNodeCol;
    
    // 入力と出力のリストを生成
    for (int i = 0; i < frameNum; i++) {
      PoolingFrame pf = new PoolingFrame(
              inNodeRow, inNodeCol, poolingRow, poolingCol);
      frameList.add(pf);
      xList.add(pf.getX());
      yList.add(pf.getY());
      dEdXList.add(pf.get_dEdX());
      dEdYList.add(pf.get_dEdY());
    }
  } // initialize()
  
  // アクセス・メソッド
  @Override
  public void set_xList(ArrayList<double[][]> xList) { 
    this.xList = xList;
    
    // プーリング枠のリストにも設定する
    for (int k = 0; k < xList.size(); k++) {
      frameList.get(k).setX(xList.get(k));
    }
  }
  
  // 順伝播処理
  @Override
  public void forward() {
    // 出力を求める
    for (int i = 0; i < frameNum; i++) {
      frameList.get(i).calcuOutput();
    }
  }

  // 逆伝播処理
  @Override
  public void backward() {
    //勾配を求める
    for (int i = 0; i < frameNum; i++) {
      frameList.get(i).calcuGradient();
    }
  }

  // 更新処理
  @Override
  public void update() {
     // 何もしない
  }
  
  // ---------------------------------------------------------------------------
  /** プーリング枠クラス */
  class PoolingFrame {
    
    // 変数宣言
    private final int poolingRow; // プーリング枠の行数
    private final int poolingCol; // プーリング枠の列数

    private double[][] x; // 入力
    private double[][] y; // 出力
    
    private double[][] dEdX; // dE/dX: 入力による誤差の勾配
    private double[][] dEdY; // dE/dY: 出力による誤差の勾配
    
    private final int[][] maxRow; // 最大値を持つノードの行インデックス
    private final int[][] maxCol; // 最大値を持つノードの列インデックス

    // 生成子
    public PoolingFrame(int inputNodeRow, int inputNodeCol,
                       int poolingRow, int poolingCol) {
    
      this.poolingRow = poolingRow;
      this.poolingCol = poolingCol;

      x = new double[inputNodeRow][inputNodeCol];
      y = new double[outNodeRow][outNodeCol];
    
      dEdX = new double[inputNodeRow][inputNodeCol];
      dEdY = new double[outNodeRow][outNodeCol];

      maxRow = new int[outNodeRow][outNodeCol];
      maxCol = new int[outNodeRow][outNodeCol];
    }
    
    // アクセス・メソッド
    public double[][] getX() { return x; }
    public void setX(double[][] x) { this.x = x; }
    
    public double[][] getY() { return y; }
    public void setY(double[][] y) { this.y = y; }
    
    public double[][] get_dEdX() { return dEdX; }
    public void set_dEdX(double[][] dEdX) { this.dEdX = dEdX; }
    
    public double[][] get_dEdY() { return dEdY; }
    public void set_dEdY(double[][] dEdY) { this.dEdY = dEdY; }
    
    // 出力を求める
    private void calcuOutput() {
      // 出力テンソルを一巡
      for (int i = 0; i < outNodeRow; i++) {
        for (int j = 0; j < outNodeCol; j++) {
          // プーリング枠内を一巡
          double max = 0;
          int maxI = 0; // 最大値を持つノードの行インデックス
          int maxJ = 0; // 最大値を持つノードの列インデックス
          for (int m = 0; m < poolingRow; m++) {
            for (int n = 0; n < poolingCol; n++) {
              // 最大値を抽出
              int p = i * poolingRow + m;
              int q = j * poolingCol + n;
              double val = x[p][q];
              if (val > max) {
                max = val;
                maxI = m;
                maxJ = n;
              }
            } // n
          } // m プーリング枠内を一巡
          
          y[i][j] = max; // 最大値を出力
          maxRow[i][j] = maxI; // 最大値を持つノードの行インデックス
          maxCol[i][j] = maxJ; // 最大値を持つノードの列インデックス
        }
      } // 出力テンソルを一巡
    } // calcuOutput()

    // 勾配を求める
    private void calcuGradient() {
      // 出力テンソルを一巡
      for (int i = 0; i < outNodeRow; i++) {
        for (int j = 0; j < outNodeCol; j++) {
          // プーリング枠内で勾配を設定
          int p = poolingRow * i + maxRow[i][j];
          int q = poolingCol * j + maxCol[i][j];
          dEdX[p][q] = dEdY[i][j];
        }
      } // 出力テンソルを一巡
    } // calcuGradient()
    
    // 更新処理
    public void update() {
      // 何もしない
    }

  } // PoolingFrame
  
} // PoolingLayer

// EOF
