/*
 *  Title: DaiJa_V4 (Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package task.control;

import active.Activator;
import data.Control;
import layer.ControlLayer;
import layer.Layer;
import task.Task;
import util.DJ;
import util.TimeStamp;
import view.PatternViewer;

/**
 * <p> 表　題: Class: Inverser1</p>
 * <p> 説　明: ニューラルネットワークによる速度制御（ステップ応答） 
 *             制御対象の逆関数を学習し、力操作量を出力する</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.07.22</p>
 */
public class Inverser1 extends Task {
  
  // 学習・試行制御
  int epoch = 50000; // 1001; // エポックの回数
  int batchNum = 1; //20; // 10; // バッチ数
  int interval = 500; //1; // 経過表示間隔
//  double lastForce; // 前回の力操作量の推定値
  
  // 学習パラメータ
  double initialCoef = 1.0; //0.8; // 重みとバイアスの初期値係数
  double eta = 0.8; //0.5; // 学習係数
  //$  double dropOutRate = 0.0; // 0.5; // ドロップ・アウト率 （0.0～0.5）
  int dataNum = 1000; // エポック毎の学習データ数　1[s]/1[ms] = 1000 [個]

  // 制御系のパラメータ
  int speedGain = 1280; //2048 // 比例ゲイン（速度ループゲイン）
  int integralGain = 0; // 積分ゲイン
  double inertia = 1.0; // 慣性質量［kgm^2/sec^2］
  double friction = 2.0; // 粘性摩擦[kg/sec]

  // 各層のノード数
  int inNodeNum = 1; // 入力層のノード数
  int midNodeNum = 4; //8; // 中間層のノード数
  int outNodeNum = 1; // 出力層のノード数
  
  // レイヤー構成
  private ControlLayer middleLayer0; // 中間層０
  //#  private ControlLayer middleLayer1; // 中間層１
  private ControlLayer outputLayer; // 出力層
  
  // 誤差変数
  double squareError; // 二乗誤差
  double meanError; // 二乗誤差データの平均
  double aveError; // エポック毎の誤差の平均値
  
  /**
   * Taskクラスで新しく作られたスレッドから呼び出される
   */
  @Override
  public void runTask() {
    DJ._print("・タスク開始日時：", TimeStamp.getTimeFormated());
    beginTime = System.currentTimeMillis(); // タスク開始時刻
    
    // パターン・ビューワ・ランチャを得る
    patternViewerFlag = true; // パターン・ビューワを表示する
    patternData0 = new double[10][dataNum]; // ID,出力(力),速度,正解(速度所望値)
    patternData1 = new double[10][dataNum]; // ID,力操作量,加速度,所望速度
    patternViewerLauncher = DJ.pattern( // 判定パターンとデータ
        PatternViewer.PATTERN_TUNER, patternData0, "Inverser1:Speed Step Response",
        PatternViewer.PATTERN_TUNER, patternData1, "Inverser1:Acceleration, Force"); 
    
    // グラフ・ビューワ・ランチャを得る
    graphShift = 10; // グラフの縦軸を移動
    graphData = new double[2];
    dataName = new String[2]; // データ名
    dataName[0] = "squareError"; // 学習時の二乗誤差
    dataName[1] = "aveError"; // 学習時の平均二乗誤差
    graphViewerLauncher = DJ.graph(epoch, interval, dataName, graphData);
    
    inverser(); // タスク本体の呼び出し
  }
  
  /** 
   * ステップ応答を正解データとして制御対象の逆関数を学習する。
   */
  public void inverser() {
    DJ._print("DaiJa_V3, Inverser1.inverser() ==========================");
    
    DJ._print("・パラメータ");
    DJ.print_(" 重みとバイアスの初期値係数:initialCoef=", initialCoef);
    DJ.print_(", 学習係数:eta=",eta);
    DJ.print(", エポックの回数:epoch=",epoch);
    
    DJ.print_(" バッチ数:batchNum=",batchNum);
    DJ.print_(", 経過表示間隔:interval=",interval);
    DJ.print(", 学習データ数：dataNum=", dataNum);
    
    DJ.print("・制御系のパラメータ");
    DJ.print_(" 比例ゲイン:speedGain=",speedGain);
    DJ.print_(", 積分ゲイン:integralGain=",integralGain);
    DJ.print_(", 慣性質量:inertia=",inertia);
    DJ.print(", 粘性摩擦:friction=",friction);
    
    DJ._print("・各層のノード数");
    DJ.print_(" 入力層ノード数:inNodeNum=",inNodeNum);
    DJ.print_(", 中間層ノード数:midNodeNum=",midNodeNum);
    DJ.print(", 出力層ノード数:outNodeNum=",outNodeNum);
    
    DJ._print("・制御系と制御対象（負荷）を設定");
    Control control = new Control();
    control.setParameter(speedGain, integralGain, 1.0, 0.0); // パラメータ
    control.resetControl();
    
    DJ._print("　入力データと正解データを作成");
    double[] inputData = new double[dataNum]; // 入力データ
    double[] correctData = new double[dataNum]; // 正解データ
    double[] forceData = new double[dataNum]; // 力操作量データ
    double[] acceleration = new double[dataNum]; // 加速度データ
    
    double inc = 1.0 / (double)dataNum;
    double clock = -inc;
    for (int j = 0; j < dataNum; j++) {
      clock = clock + inc;
      // DJ._print("・制御対象の入力（力操作量）と出力（応答）を収集する");
      control.execute();
      forceData[j] = control.getForce(); // 力の操作量
      acceleration[j] = control.getAcceleration() * 100.0; // 応答加速度

      inputData[j] = clock; // 入力データ
/**/      correctData[j] = control.speed_step(clock); // 正解データ（理論ステップ速度）
//**/      correctData[j] = control.spd_ramp(clock) * 0.6; // 正解データ（理論ランプ速度）;
//**/      correctData[j] = control.speed_sCurve(clock); // 正解データ（Ｓ字カーブ速度）;

      patternData0[0][j] = clock;
      patternData1[0][j] = clock;
    }
    
    DJ._print("・学習対象（制御系と制御対象）の実行準備");
    control.setParameter(0, 0, inertia, friction); // パラメータ
    control.stopSpeedControl(); // 速度制御を止める（無効化する）
    
    // DJ.print(" 順序をランダムにシャッフルしたインデックスのリスト");
    //#    ArrayList<Integer> indexList = DJ.permutationRandom(dataNum);
    //#    DJ.print("inputData", inputData);
    
    DJ._print("・ニューラルネットの各層の初期化");
    middleLayer0 = new ControlLayer(inNodeNum, midNodeNum);
    middleLayer0.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
//    middleLayer0.initialize(eta, initialCoef, Activator.SIGMOID, Layer.ADA_GRAD);
    double[] mid0X = middleLayer0.getX();
    double[] mid0Y = middleLayer0.getY();
    // double[] mid0C = middleLayer0.getC();
    // double[] mid0E = middleLayer0.getE();
    // double[] mid0dEdX = middleLayer0.getdEdX();
    //# double[] mid0dEdY = middleLayer0.getdEdY(); // 参照を代入（注１）

    // middleLayer1 = new ControlLayer(midNodeNum, midNodeNum);
    // middleLayer1.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
    //# double[] mid1X = middleLayer1.getX(); // 参照を代入（注３）
    // double[] mid1Y = middleLayer1.getY();
    // double[] mid1C = middleLayer1.getC();
    // double[] mid1E = middleLayer1.getE();
    // double[] mid1dEdX = middleLayer1.getdEdX();
    //# double[] mid1dEdY = middleLayer1.getdEdY(); // 参照を代入（注２）

    outputLayer = new ControlLayer(midNodeNum, outNodeNum);
    outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.SGD);
    // outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.ADA_GRAD);
    //# double[] outX = outputLayer.getX(); // 参照を代入（注４）
    double[] outY = outputLayer.getY();
    double[] outC = outputLayer.getC();
    double[] outE = outputLayer.getE();
    double[] outdEdX = outputLayer.getdEdX();
    // double[] outdEdY = outputLayer.getdEdY();

    double[] outS = new double[outC.length]; // 応答速度

    // middleLayer0.setdEdY(mid1dEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
    // middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
    // middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
    // outputLayer.setX(mid1Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    
    middleLayer0.setdEdY(outdEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
    // middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
    // middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
    outputLayer.setX(mid0Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    
    // DJ._print("PatternViewerへの参照を得る");
    patternViewer = patternViewerLauncher.getPatternViewer();
    // DJ._print("GraphViewerへの参照を得る");
    graphViewer = graphViewerLauncher.getGraphViewer();
    if (graphViewer != null) graphViewer.shiftGraphAxis(graphShift); // グラフの縦軸をシフトする
    
    DJ._print(" ##### ニューラルネットの学習開始 #####");
    for (int i = 0; i <= epoch; i++) {
      startTime = System.nanoTime(); // 実行開始時刻
      intervalFlag = (i % interval == interval - 1) | 
              (i == epoch); //経過表示フラグ

      // DJ._print("・学習用データのインデックスをシャッフル");
      //#      Collections.shuffle(indexList);
      
      // DJ.print("・学習誤差と試行誤差の初期化");
      squareError = 0.0; // 二乗誤差
      meanError = 0.0; // 二乗誤差データの平均
      
      // 制御をリセット
      control.resetControl();
    
      // 全データを学習
      for (int j = 0; j < dataNum; j++) {
        // DJ._print("・エポックの実行ループ開始 ----------------------------");
        // DJ.print(" エポック:i=" + i + ", 学習データ:j=" + j);
        
        //#        int index = indexList.get(j); // ランダム・インデックス
        int index = j; // 非ランダム・インデックス
        // DJ._print(" 学習用データID：learnIndex=", index);
        
        // DJ._print(" ##### 前向き伝播の呼び出し #####");
        // DJ._print("・入力画像から入力データを取り出し");
        mid0X[0] = inputData[index]; // 入力データ
        middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
        //#        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
        
        // DJ._print(" 制御対象の実行（力操作量　--->　応答速度）");
        // DJ._print(" 推定された力操作量を制御対象（負荷）に与える");
        control.setForce(outY[0] / Control.DAC_FACTOR); // 力の操作量
        control.execute(); // 制御対象の実行
        outS[0] = control.getSpeed(); // 応答速度
        
        // DJ._print(" ##### 後向き伝播の呼び出し #####");
        outC[0] = correctData[index]; // 正解データを取出し
        //#        outputLayer.backward(outC); // 出力層の逆伝播処理を呼び出し
        outputLayer.backward(outC, outS); // 出力層の逆伝播処理を呼び出し
        //#        middleLayer1.backward(); // 中間層１の逆伝播処理を呼び出し
        middleLayer0.backward(); // 中間層０の逆伝播処理を呼び出し
        
        // DJ._print("・バッチ処理");
        if ((j % batchNum) == 0 ) {
          // DJ._print(" ##### 更新の呼び出し #####");
          middleLayer0.update(); // 中間層０の更新処理を呼び出し
          //#          middleLayer1.update(); // 中間層１の更新処理を呼び出し
          outputLayer.update(); // 出力層の更新処理を呼び出し
        }

        // DJ._print(" ##### 学習時の誤差の保存 #####");
        if (intervalFlag) { // 実行時間に影響を与える
          // 二乗誤差の算出
          squareError = outE[0] * outE[0] / 2.0;
          meanError = meanError + squareError;
        
          // パターン表示用データの保存
          patternData0[1][j] = outY[0] * 0.1; // ＮＮで算出された力操作量
          patternData0[2][j] = outS[0]; // 応答速度
          patternData0[3][j] = correctData[j]; // 所望速度
          patternData0[4][j] = control.getAcceleration() * 100.0; // 応答加速度
          patternData0[5][j] = 10.0; // ダミー
          patternData0[6][j] = 10.0; // ダミー
          patternData0[7][j] = forceData[j] * Control.DAC_FACTOR * 0.1; // 力操作量データ
          patternData0[8][j] = acceleration[j]; // 加速度データ
          patternData0[9][j] = 10.0; // ダミー
          
          patternData1[1][j] = 10.0; // ダミー
          patternData1[2][j] = 10.0; // ダミー
          patternData1[3][j] = 10.0; // ダミー
          patternData1[4][j] = patternData0[1][j]; // ＮＮで算出された力操作量
          patternData1[5][j] = patternData0[4][j]; // 応答加速度
          patternData1[6][j] = patternData0[2][j]; // 応答速度
          patternData1[7][j] = patternData0[7][j]; // 力操作量データ
          patternData1[8][j] = patternData0[8][j] - 0.01; // 加速度データ
          patternData1[9][j] = patternData0[3][j]; // 所望速度
        }
        
        // DJ._print(" End of one data ---------------");
      } // 全データを学習
      // DJ._print(" End of one epoch ---------------");
      
      // DJ._print("　エポック毎の誤差の平均値を保存する");
      aveError = meanError / dataNum ;
      
      // 実行時間の累積
      endTime = System.nanoTime(); // 休止時刻
      double lapTime_ = (endTime - startTime) / 1000000.0;
      if (lapTime_ > 0.0) lapTime = lapTime_; // オーバーフロー対策
      totalTime = totalTime + lapTime; // 経過時間を追加

       // 経過表示インターバル
      if (intervalFlag) { // 実行時間に影響を与える
        // DJ._print(" ##### １エポックの実行結果 #####");
        DJ.print_(" i=" + i);
        DJ.print_(", squareError = ", squareError);
        DJ.print(", aveError = ", aveError);

        // グラフ表示用データを代入
        graphData[0] = squareError;
        graphData[1] = aveError;
        updateViewer(i); // ビューワの表示を更新
          
        DJ.print_(", lapTime = ", lapTime); DJ.print("[msec]");
        
        // スレッドの休止（実行速度の調整および経過表示のため）
        synchronized(this) {
          try {
            // DJ.print("Ｅnter to wait(sleepTime)");
            wait(SLEEP_TIME); // タイムアウト付きで待機状態
            // DJ.print("Resume from wait(sleepTime)");
            if (pauseFlag) wait(); // 休止状態
          
          }
          catch (InterruptedException e) {
             DJ.print("***** ERROR ***** " + getClass().getName() + "\n"
                 + " Exception occur in wait(sleepTime):" + e.toString());
          }
        } // synchronized()
      } // interval
    
      // 実行処理を強制的に終了させる
      if (abortFlag) {
        DJ._print("##### Abort action requested");
        epoch = i; // 現在のエポック回数iをepochに代入し、実行を強制的に終了させる
      }
      
      // DJ._print(" End of one epoch ---------------------------------------");
    }
    DJ._print(" End of all epoch --------------------------------------------");
    
    DJ._print("・エポック実行回数");
    DJ.print(" Last epoch = ", epoch);

    DJ._print("・最終誤差: ");
    for (int k = 0; k < graphData.length; k++) {
      DJ.print("  " + dataName[k] + "=" + graphData[k]);
    }
    
//    DJ._print("・力操作量の推定値と応答速度: ");
//    DJ.print("P", patternData0);
    
    DJ._print_("・最終学習時のデータ  ");
    DJ.print("収集データ数: ", dataNum);
    DJ.print("力操作量の推定値: f~", patternData0[1]);
    DJ.print("応答速度: v", patternData0[2]);
    
    // 処理時間の算出
    DJ._print_("・総実行時間：" + (totalTime / 1000.0) + " [sec]");
    double aveTime = totalTime / epoch;
    DJ.print(", 平均実行時間：" + aveTime + " [msec/epoch]");
    finishTime = System.currentTimeMillis(); // タスク終了時刻
    DJ.print_(", タスク処理時間：" + 
            ((finishTime - beginTime) / 1000.0) + " [sec]");
    DJ.print(", タスク終了日時：", TimeStamp.getTimeFormated());
    
    DJ._print("##### Inverser1 終了 #####");
  } // inverser()
} // Inverser1 Class

// End of file
