/*
 *  Title: DaiJa_V4 ( Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package layer;

import active.Activator;
import java.util.Random;
import util.DJ;

/**
 * <p> 表　題: Class: LinearLayer</p>
 * <p> 説　明:１次元層クラス</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.3.12</p>
 */
public abstract class Layer {
  
  // 勾配降下法の選択肢
  public static final int SGD = 0;
  public static final int ADA_GRAD = 1;
  
  // ノード数
  public int inNodeNum; // Input Node Number: 入力ノード数
  public int outNodeNum; // Output Node Number: 出力ノード数
  
  // 変数宣言
  public double[] x; // Input Node: 入力ノード
  public double[][] w; // Weight： 重み
  public double[] b; // Bias： バイアス
  public double[] u; // u： 内部状態量
  public double[] y; // Output Node: 出力ノード
  public double[] c; // Correct answer: 正解
  public double[] e; // Error: 誤差
  
  public double[] dEdX; // dE/dX: 入力による誤差の勾配
  public double[][] dEdW; // dE/dW: 重みによる誤差の勾配
  public double[] dEdB; // dE/dB: バイアスによる誤差の勾配
  public double[] dEdU; // dE/dU: 内部状態量による誤差の勾配
  public double[] dYdU; // dY/dU: 活性化関数の微分
  public double[] dEdY; // dE/dY: 出力による誤差の勾配
  
  public double[][] dEdW_sum; // バッチ用の重みによる誤差勾配の積算値
  public double[] dEdB_sum; // バッチ用のバイアスによる誤差勾配の積算値
  
  public double[][] hWeight; // 重みによる誤差勾配更新のAdaGrad法補助変数
  public double[] hBias; // バイアスによる誤差勾配更新のAdaGrad法補助変数
  
  public Activator activator; // 活性化関数
  public String activatorName; // 活性化関数の名称
  public GradientDescent gradientDescent; // 勾配降下法
  
  public double eta; // 学習係数
  
  // 生成子
  public Layer(int inputNodeNum, int outputNodeNum) {
    inNodeNum = inputNodeNum;
    outNodeNum = outputNodeNum;
    
    x = new double[inputNodeNum];
    w = new double[outputNodeNum][inputNodeNum];
    b = new double[outputNodeNum];
    u = new double[outputNodeNum];
    y = new double[outputNodeNum];
    c = new double[outputNodeNum];
    e = new double[outputNodeNum];
    
    dEdX = new double[inputNodeNum];
    dEdW = new double[outputNodeNum][inputNodeNum];
    dEdB = new double[outputNodeNum];
    dEdU = new double[outputNodeNum];
    dYdU = new double[outputNodeNum];
    dEdY = new double[outputNodeNum];
    
    dEdW_sum = new double[outputNodeNum][inputNodeNum];
    dEdB_sum = new double[outputNodeNum];
    
    hWeight = new double[outputNodeNum][inputNodeNum];
    hBias = new double[outputNodeNum];
  }
  
  //初期化
  public void initialize() { }  

  public void initialize(double eta, double initialCoef,
          String activatorName, int gradientDescentType) {
    
    this.activatorName = activatorName; // 活性化関数の名称
    
    this.eta = eta; // 学習係数
    Random randum =  DJ.getRandom(); // 乱数
    
    for (int i = 0; i < outNodeNum; i++) {
      for (int j = 0; j < inNodeNum; j++) {
        w[i][j] = initialCoef * randum.nextGaussian(); // 重み
        hWeight[i][j] = 1.0E-8;
      }
      b[i] = initialCoef * randum.nextGaussian(); // バイアス
      hBias[i] = 1.0E-8;
    }
    
    // 活性化関数を生成
    activator = Activator.createActivator(activatorName);
    
    // 勾配降下法を生成
    switch (gradientDescentType) {
      case ADA_GRAD:
        gradientDescent = new AdaGrad();
        break;
      case SGD:
      default:      
        gradientDescent = new SGD();
        break;
    }
  }

  // アクセス・メソッド
  public double[] getX() { return x; }
  public void setX(double[] x) { this.x = x; }
  public double[][] getW() { return w; }
  public void setW(double[][] w) { this.w = w; }
  public double[] getB() { return b; }
  public void setB(double[] b) { this.b = b; }
  public double[] getU() { return u; }
  public void setU(double[] u) { this.u = u; }
  public double[] getY() { return y; }
  public void setY(double[] y) { this.y = y; }
  public double[] getC() { return c; }
  public void setC(double[] c) { this.c = c; }
  public double[] getE() { return e; }
  public void setE(double[] e) { this.e = e; }
  
  public double[] getdEdX() { return dEdX; }
  public void setdEdX(double[] dEdX) { this.dEdX = dEdX; }
  public double[] getdEdY() { return dEdY; }
  public void setdEdY(double[] dEdY) { this.dEdY = dEdY; }

  
  // 順伝播処理 ----------------------------------------------------------------
//  public abstract void forward();

  public void forward() {
    calcuOutput(); // 出力を求める
  }
  
  // 出力を求める
  private void calcuOutput() {
    for (int i = 0; i < outNodeNum; i++) {
      u[i] = b[i];
      for (int j = 0; j < inNodeNum; j++) {
        u[i] = u[i] +  w[i][j] * x[j];
      }
    }
    activator.function(u, y);
  }
  
  // 逆伝播処理（出力層用）-----------------------------------------------------
  public void backward(double[] c) {
    calcuGradient(c); //勾配を求める
  }
  
  // 逆伝播処理（中間層用：引数なし）
  public void backward() {
    calcuGradient(); //勾配を求める
  }
  
  // 誤差を求める
  public void calcuError() {
    for (int i = 0; i < outNodeNum; i++) {
      e[i] = y[i] - c[i];
    }
    dEdY = e;
  }
  
  // 勾配を求める
  private void calcuGradient(double[] c) {
    activator.derivative(u, y, dYdU); // 活性化関数の微分

    for (int j = 0; j < inNodeNum; j++) {
      dEdX[j] = 0.0;
    }
    
    for (int i = 0; i < outNodeNum; i++) {
      // 誤差を求める
      e[i] = y[i] - c[i]; // 注）二乗誤差の微分は2*eとなる。2は除外
      dEdU[i] = e[i] * dYdU[i]; // 内部状態量による誤差の勾配
      dEdB[i] = dEdU[i]; // バイアスによる誤差の勾配
      dEdB_sum[i] = dEdB_sum[i] + dEdB[i]; // バッチ用のバイアスによる誤差の勾配
      for (int j = 0; j < inNodeNum; j++) {
        dEdW[i][j] = dEdU[i] * x[j]; // 重みによる誤差の勾配
        dEdW_sum[i][j] = dEdW_sum[i][j] + dEdW[i][j]; // バッチ用の重みによる誤差の勾配
        dEdX[j] = dEdX[j] + w[i][j] * dEdU[i]; // 入力による誤差の勾配
      }
    }

  }
  
  //勾配を求める
  private void calcuGradient() {
    
    switch (activatorName) {
      case Activator.SIGMOID: // シグモイド関数
        for (int i = 0; i < outNodeNum; i++) {
         // シグモイド関数の微分
          dYdU[i] = (1.0 - y[i]) * y[i]; // シグモイド関数の微分
          dEdU[i] = dEdY[i] * dYdU[i]; // 内部状態量による誤差の勾配
          dEdB[i] = dEdU[i]; // バイアスによる誤差の勾配
          dEdB_sum[i] = dEdB_sum[i] + dEdB[i]; // バッチ用のバイアスによる誤差の勾配
          for (int j = 0; j < inNodeNum; j++) {
            dEdW[i][j] = dEdU[i] * x[j]; // 重みによる誤差の勾配
            dEdW_sum[i][j] = dEdW_sum[i][j] + dEdW[i][j]; // バッチ用の重みによる誤差の勾配
            dEdX[j] = dEdX[j] + w[i][j] * dEdU[i]; // 入力による誤差の勾配
          }
        }
        break;
      case Activator.RELU: // ReLU関数 y[i] = reLU(u[i])
        // ReLU演算用に初期化
        for (int j = 0; j < inNodeNum; j++) {
          dEdX[j] = 0.0; // 入力による誤差の勾配
        }

        for (int i = 0; i < outNodeNum; i++) {
           // ReLU関数の微分
          if (u[i] >= 0.0) {
            dYdU[i] = 1.0;
            dEdU[i] = dEdY[i]; // 内部状態量による誤差の勾配
            dEdB[i] = dEdU[i]; // バイアスによる誤差の勾配
            dEdB_sum[i] = dEdB_sum[i] + dEdB[i]; // バッチ用のバイアスによる誤差の勾配
            for (int j = 0; j < inNodeNum; j++) {
              dEdW[i][j] = dEdU[i] * x[j]; // 重みによる誤差の勾配
              dEdW_sum[i][j] = dEdW_sum[i][j] + dEdW[i][j]; // バッチ用の重みによる誤差の勾配
              dEdX[j] = dEdX[j] + w[i][j] * dEdU[i]; // 入力による誤差の勾配
            }
          }
        }
        break;
      case Activator.IDENTITY: // 恒等関数
      default:
        for (int j = 0; j < inNodeNum; j++) {
          dEdX[j] = 1.0; // 入力による誤差の勾配
        }
        break;    
    }
    
  } // calcuGradient()
  
  // 更新処理 ------------------------------------------------------------------
  public void update() {
    gradientDescent.descent(); // 勾配降下法
  }
  
  // パラメータを更新する
  private void updateGradient() {
    // 勾配降下法（確率的勾配降下法SDG）
    for (int i = 0; i < outNodeNum; i++) {
      b[i] = b[i] - eta * dEdB_sum[i]; // バイアスによる誤差の勾配
      dEdB_sum[i] = 0.0;
      for (int j = 0; j < inNodeNum; j++) {
        w[i][j] = w[i][j] - eta * dEdW_sum[i][j]; // 重みによる誤差の勾配
        dEdW_sum[i][j] = 0.0;
      }
    }
  }

// -----------------------------------------------------------------------------
  
// 勾配降下法
abstract class GradientDescent {
  abstract void descent();
}
  
  // 勾配降下法（確率的勾配降下法SGD）
  class SGD extends GradientDescent {
    @Override
    void descent() {
      for (int i = 0; i < outNodeNum; i++) {
        b[i] = b[i] - eta * dEdB_sum[i]; // バイアスによる誤差の勾配
        dEdB_sum[i] = 0.0;
        for (int j = 0; j < inNodeNum; j++) {
          w[i][j] = w[i][j] - eta * dEdW_sum[i][j]; // 重みによる誤差の勾配
          dEdW_sum[i][j] = 0.0;
        }
      }
    } // descent()
  } // SGD
  
  // 勾配降下法（AdaGrad法）
  class AdaGrad extends GradientDescent {
    @Override
    void descent() {
      for (int i = 0; i < outNodeNum; i++) {
        hBias[i] = hBias[i] + (dEdB_sum[i] * dEdB_sum[i]);
        b[i] = b[i] - (eta / Math.sqrt(hBias[i]) * dEdB_sum[i]);
        dEdB_sum[i] = 0.0;
        for (int j = 0; j < inNodeNum; j++) {
          hWeight[i][j] = hWeight[i][j] + (dEdW_sum[i][j] * dEdW_sum[i][j]);
          w[i][j] = w[i][j] - (eta / Math.sqrt(hWeight[i][j]) * dEdW_sum[i][j]);
          dEdW_sum[i][j] = 0.0;
        }
      }
    } // descent()
  } // class AdaGrad

} // class Layer

// EOF 