/*
 *  Title: DaiJa_V4 ( Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package layer;

import java.util.ArrayList;
import java.util.Random;
import static layer.Layer.ADA_GRAD;
import static layer.Layer.SGD;
import util.DJ;

/**
 * <p> 表　題: Class: ConvoluteLayer</p>
 * <p> 説　明:畳込み層クラス</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.3.26</p>
 */
public class ConvoluteLayer extends PlaneLayer {
  
  // 変数宣言
  private double eta; // 学習係数
  private double initialCoef; // 初期化係数
  private Random randum; // 乱数
  
  private int filterNum;// フィルタの枚数
  private final ArrayList<Filter> filterList; // フィルター・リスト
  private ArrayList<double[][]> filterWeightList; // フィルターの重みリスト
  
  public ConvoluteLayer(int inputNodeRow, int inputNodeCol) {
    super(inputNodeRow, inputNodeCol);

    filterList = new ArrayList<>();
    filterWeightList = new ArrayList<>();
  }

  // 初期化
  public void initialize(double eta, double initialCoef,
          int filterNum, int filterHight, int filterWidth, int filterStride,
           int gradientDescentType) {
    this.eta = eta;
    this.initialCoef = initialCoef;
    this.filterNum = filterNum; // フィルタの枚数
    
    outNodeRow = (inNodeRow - filterHight) / filterStride + 1;
    outNodeCol = (inNodeCol - filterWidth) / filterStride + 1;
    
    randum =  DJ.getRandom(); // 乱数
    
    // 入出力、フィルタ、勾配のリストを生成
    double[][] x = new double[inNodeRow][inNodeCol]; // 入力

    
    for (int i = 0; i < filterNum; i++) {
      Filter filter = new Filter(filterHight, filterWidth, filterStride);
      filter.initialize(gradientDescentType);
      filterList.add(filter);
      filter.setX(x); // フィルタにｘへの参照を設定
      xList.add(filter.getX());
      yList.add(filter.getY());
      dEdXList.add(filter.get_dEdX());
      dEdYList.add(filter.get_dEdY());
      filterWeightList.add(filter.getFilterWeight());
    }
  } // initialize()
  
  @Override
  public void set_dEdYList(ArrayList<double[][]> dEdYList) { 
    this.dEdYList = dEdYList;
    
    // フィルタのリストにも設定する
    for (int k = 0; k < dEdYList.size(); k++) {
      filterList.get(k).set_dEdY(dEdYList.get(k));
    }
  }
  
  public ArrayList<double[][]> getFilterWeightList() {
    return filterWeightList; }
  public void setFilterWeightList(ArrayList<double[][]> filterWeightList) { 
    this.filterWeightList = filterWeightList;
  }
  
  // 順伝播処理
  @Override
  public void forward() {
    // 出力を求める
    for (int i = 0; i < filterNum; i++) {
      ((Filter)filterList.get(i)).calcuOutput();
    }
  }

  // 逆伝播処理
  @Override
  public void backward() {
    //勾配を求める
    for (int i = 0; i < filterNum; i++) 
      ((Filter)filterList.get(i)).calcuGradient();
    
  }

  // 更新処理
  @Override
  public void update() {
     // 勾配降下法
    for (int i = 0; i < filterNum; i++) {
        filterList.get(i).gradientDescent.descent();
    }
  }
  
  
// -----------------------------------------------------------------------------
/** フィルタ・クラス */
class Filter {
  
  // 変数宣言
  private final int filterHight; // フィルタの行数
  private final int filterWidth; // フィルタの列数
  private final int filterStride; // フィルタの移動量
  
  private double[][] x; // 入力
  private double[][] y; // 出力
  private double[][] filterWeight; // フィルタ値テンソル（フィルタの重み）
  private double filterBias; // フィルタ・バイアス
  private final double[][] u; // u： 内部状態量
  
  private double[][] dEdX; // dE/dX: 入力による誤差の勾配
  private double[][] dEdY; // dE/dY: 出力による誤差の勾配
//  private double[][] dEdW; // dE/dW: 重みによる誤差の勾配
//  private double dEdB; // dE/dB: バイアスによる誤差の勾配
//  private double[][] dEdU; // dE/dU: 内部状態量による誤差の勾配
//  private double[][] dYdU; // dY/dU: 活性化関数の微分

  private final double[][] dEdW_sum; // バッチ用の重みによる誤差勾配の積算値
  private double dEdB_sum; // バッチ用のバイアスによる誤差勾配の積算値
  
  private final double[][] hWeight; // 重みによる誤差勾配更新のAdaGrad法補助変数
  private double hBias; // バイアスによる誤差勾配更新のAdaGrad法補助変数
  
  private GradientDescent gradientDescent; // 勾配降下法
//#  private Activator activator; // 活性化関数
//#  private ReLU reLU; // 活性化関数
  
  // 生成子
  public Filter(int filterHight, int filterWidth, int filterStride) {
    this.filterHight = filterHight;
    this.filterWidth = filterWidth;
    this.filterStride = filterStride;
    
//    x = new double[inNodeRow][inNodeCol];
    filterWeight  = new double[filterHight][filterWidth];
    filterBias = 0.0;
    u  = new double[outNodeRow][outNodeCol];
    y = new double[outNodeRow][outNodeCol];
    
    dEdX = new double[inNodeRow][inNodeCol];
    dEdY = new double[outNodeRow][outNodeCol];
    
//    dEdW = new double[outNodeRow][outNodeCol];
//    dEdB = 0.0;
//    dEdU = new double[outNodeRow][outNodeCol];
//    dYdU = new double[outNodeRow][outNodeCol];
    
    dEdW_sum = new double[filterHight][filterWidth];
    dEdB_sum = 0.0;
    
    hWeight = new double[filterHight][filterWidth];
    hBias = 0.0;
    
  }
    
//#  public void initialize(String activatorName, int gradientDescentType) {
  public void initialize(int gradientDescentType) {
    
    for (int i = 0; i < filterHight; i++) {
      for (int j = 0; j < filterWidth; j++) {
        filterWeight[i][j] = initialCoef * randum.nextGaussian(); // 重み
        hWeight[i][j] = 1.0E-8;
      }
      filterBias = initialCoef * randum.nextGaussian(); // バイアス
      hBias = 1.0E-8;
    }
    
    // 活性化関数を生成
//#    activator = Activator.createActivator(activatorName);
//#    reLU = new ReLU();
    
    // 勾配降下法を生成
    switch (gradientDescentType) {
      case ADA_GRAD:
        gradientDescent = new AdaGrad();
        break;
      case SGD:
      default:      
        gradientDescent = new SGD();
        break;
    }
    
  } // initialize()
  
    // アクセス・メソッド
    public double[][] getX() { return x; }
    public void setX(double[][] x) { this.x = x; }
    
    public double[][] getY() { return y; }
    public void setY(double[][] y) { this.y = y; }
    
     public double[][] getFilterWeight() { return filterWeight; }
    public void setFilterWeight(double[][] filterWeight) {
      this.filterWeight = filterWeight; }
    
   public double[][] get_dEdX() { return dEdX; }
    public void set_dEdX(double[][] dEdX) { this.dEdX = dEdX; }
    
    public double[][] get_dEdY() { return dEdY; }
    public void set_dEdY(double[][] dEdY) { this.dEdY = dEdY; }
    
  // 出力を求める
  private void calcuOutput() {
    // 出力テンソルを一巡
    for (int i = 0; i < outNodeRow; i++) {
      for (int j = 0; j < outNodeCol; j++) {
        u[i][j] = filterBias;
        // フィルタ内を一巡
        for (int m = 0; m < filterHight; m++) {
          for (int n = 0; n < filterWidth; n++) {
            int p = i * filterStride + m;
            int q = j * filterStride + n;
            u[i][j] = u[i][j] + x[p][q] * filterWeight[m][n];
          }
        }
        
        // 活性化関数：ReLU関数
//#        activator.function(u, y);
//#        reLU.function(u, y);
        if (u[i][j] <= 0.0) y[i][j] = 0.0; 
        else y[i][j] = u[i][j];
      }
    }
  } // calcuOutput()
  
  //勾配を求める
  private void calcuGradient() {
    
//#    activator.derivative(u, y, dYdU); // 活性化関数の微分
    
    // 出力テンソル内一巡
    for (int i = 0; i < outNodeRow; i++) {
      for (int j = 0; j < outNodeCol; j++) {
        
        double delta = dEdY[i][j]; // delta
        
        if (u[i][j] > 0.0) {  // ReLu関数：uの正負で積分値は１か０
          
          // フィルタ内を一巡
          for (int m = 0; m < filterHight; m++) {
            for (int n = 0; n < filterWidth; n++) {
              int p = i * filterStride + m; // 入力データの行インデックス
              int q = j * filterStride + n; // 入力データの列インデックス
              
              // フィルタの重みによる誤差の勾配dEdW
              double deltaX = delta * x[p][q];
              dEdW_sum[m][n] = dEdW_sum[m][n] + deltaX;
               
              // 入力xによる誤差の勾配dEdX
              dEdX[p][q] = delta * filterWeight[m][n];
            } // n
          } // m
          
          // バイアスbによる誤差の勾配dEdB
          dEdB_sum = dEdB_sum + delta;
          
        } // ReLu
        
      } // j
    } // i
    
  } // calcuGradient()
  
  
  // ---------------------------------------------------------------------------
  
  // 勾配降下法（確率的勾配降下法SGD）
  class SGD extends GradientDescent {
    
    @Override
    void descent() {
      // バイアスによる誤差の勾配
      filterBias = filterBias - eta * dEdB_sum; // バイアスによる誤差の勾配
        dEdB_sum = 0.0;
      for (int i = 0; i < filterHight; i++) {
        for (int j = 0; j < filterWidth; j++) {
          // 重みによる誤差の勾配
          filterWeight[i][j] = filterWeight[i][j] - eta * dEdW_sum[i][j];
          dEdW_sum[i][j] = 0.0;
        }
      }
    } // descent()
    
  } // SGD
  
  // 勾配降下法（AdaGrad法）
  class AdaGrad extends GradientDescent {
    
    @Override
    void descent() {
      // バイアスによる誤差の勾配
      hBias = hBias + (dEdB_sum * dEdB_sum);
      filterBias = filterBias - (eta / Math.sqrt(hBias) * dEdB_sum);
      dEdB_sum = 0.0;
      for (int i = 0; i < filterHight; i++) {
        for (int j = 0; j < filterWidth; j++) {
          // 重みによる誤差の勾配
          hWeight[i][j] = hWeight[i][j] + (dEdW_sum[i][j] * dEdW_sum[i][j]);
          filterWeight[i][j] = filterWeight[i][j]
                  - (eta / Math.sqrt(hWeight[i][j]) * dEdW_sum[i][j]);
          dEdW_sum[i][j] = 0.0;
        }
      }
    } // descent()
    
  } // class AdaGrad

} // class Filter
  
} // class ConvoluteLayer

// EOF

