/*
 *  Title: DaiJa_V4 (Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package task.control;

import active.Activator;
import data.Control;
import layer.ControlLayer;
import layer.Layer;
import task.Task;
import util.DJ;
import util.TimeStamp;
import view.PatternViewer;

/**
 * <p> 表　題: Class: PathTracker</p>
 * <p> 説　明: 目標経路より力操作量を推定する逆関数を学習
 *     ランプ速度波形を目標経路とする。
 *     PI補償と学習制御を比較する。
 *     ニューラルネットワークは２層か３層を選択する。</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2021.01.19</p>
 */
public class PathTracker extends Task {
  
  static final int TWO_LAYER = 2; // ２層
  static final int THREE_LAYER = 3; // ３層
  int layerNum = TWO_LAYER; // THREE_LAYER; // ディフォルト層数
  
  // 学習・試行制御
  int epoch = 1000;  // エポックの回数
  int batchNum = 1; //20; // 10; // バッチ数
  int interval = 10; //1; // 経過表示間隔
  
  // 学習パラメータ
  double initialCoef = 0.5; //0.01; // 重みとバイアスの初期値係数
  double eta = 0.3; //0.01; // 学習係数
  int dataNum = 1000; //1000; // エポック毎の学習データ数　1[s]/1[ms] = 1000 [個]

  // 制御系と制御対象のパラメータ
  int speedGain = 2048; //1280; // 比例ゲイン（速度ループゲイン）
  int integralGain = 0; //2048; // 積分ゲイン
  double inertia = 1.0; // 慣性質量［kgm^2/sec^2］
  double friction = 2.0; // 粘性摩擦[kg/sec]

  // 各層のノード数
  int inNodeNum = 2; //3; // 入力層のノード数
  int midNodeNum = 8; // 中間層のノード数
  int outNodeNum = 1; // 出力層のノード数
  
  // レイヤー構成
  private ControlLayer middleLayer0; // 中間層０
  private ControlLayer middleLayer1; // 中間層１
  private ControlLayer outputLayer; // 出力層
  
  /**
   * Taskクラスで新しく作られたスレッドから呼び出される
   */
  @Override
  public void runTask() {
    DJ._print("・タスク開始日時：", TimeStamp.getTimeFormated());
    beginTime = System.currentTimeMillis(); // タスク開始時刻
    
    // パターン・ビューワ・ランチャを得る
    patternViewerFlag = true; // パターン・ビューワを表示する
    patternData0 = new double[10][dataNum]; // ID,出力(力),速度,正解(速度所望値)
    patternData1 = new double[10][dataNum]; // ID,力操作量,加速度,所望速度
    patternViewerLauncher = DJ.pattern( // 判定パターンとデータ
        PatternViewer.PATTERN_TUNER, patternData0, "PathTracker:PID制御", 
        PatternViewer.PATTERN_TUNER, patternData1,  "PathTracker:学習制御"); 
    
    // グラフ・ビューワ・ランチャを得る
    graphShift = 8; // グラフの縦軸を移動
    graphData = new double[2];
    dataName = new String[2]; // データ名
    dataName[0] = "squareError"; // 学習時の二乗誤差
    dataName[1] = "aveError"; // 学習時の平均二乗誤差
    graphViewerLauncher = DJ.graph(epoch, interval, dataName, graphData);
    
    pathTracker(); // タスク本体の呼び出し
  }
  
  /** 
   * 目標経路より力操作量を推定する逆関数を学習
   */
  public void pathTracker() {
    DJ._print("PathTracker.pathTracker() ==========================");
    
    DJ._print("・パラメータ");
    DJ.print_(" 重みとバイアスの初期値係数:initialCoef=", initialCoef);
    DJ.print_(", 学習係数:eta=",eta);
    DJ.print(", エポックの回数:epoch=",epoch);
    
    DJ.print_(" バッチ数:batchNum=",batchNum);
    DJ.print_(", 経過表示間隔:interval=",interval);
    DJ.print(", 学習データ数：dataNum=", dataNum);
    
    DJ.print("・制御系と制御対象のパラメータ");
    DJ.print_(" 比例ゲイン:speedGain=",speedGain);
    DJ.print_(", 積分ゲイン:integralGain=",integralGain);
    DJ.print_(", 慣性質量:inertia=",inertia);
    DJ.print(", 粘性摩擦:friction=",friction);
    
    DJ._print("・各層のノード数");
    DJ.print_(" 入力層ノード数:inNodeNum=",inNodeNum);
    DJ.print_(", 中間層ノード数:midNodeNum=",midNodeNum);
    DJ.print(", 出力層ノード数:outNodeNum=",outNodeNum);
    
    DJ._print("・学習対象（制御系と制御対象）の実行準備");
    double[][] pidData = new double[4][dataNum]; // PID制御試行結果
    Control control = new Control();
//    control.startSpeedControl(); // 速度制御を開始する（有効化する）
//    control.startFeedback(); // フィードバック処理を開始する（有効化する）
    control.setParameter(speedGain, integralGain, inertia, friction); // パラメータ
    control.resetControl();
    
    DJ._print("・入力データと教師データを作成");
    DJ._print(" 目標値を収集する");
    //# double[] inputData = new double[dataNum]; // 入力データ
    double[] correctData = new double[dataNum]; // 正解データ
    double[] accelCmd = new double[dataNum]; // 加速度目標値
    double[] speedCmd = new double[dataNum]; // 速度目標値
    double[] positionCmd = new double[dataNum]; // 位置目標値
    
    accelCmd[0] = 0.0;
    speedCmd[0] = 0.0;
    positionCmd[0] = 0.0;
    double clock = 0.0;
    double deltaClock = 1.0 / dataNum;
    for (int j = 1; j < dataNum; j++) {
      clock = clock + deltaClock;
      patternData0[0][j] = clock;
      patternData1[0][j] = clock;
      //# inputData[j] = clock; // クロック
     
      // 速度経路
      if (clock > 0.7) {
        speedCmd[j] = 0.6; // 速度目標値
        accelCmd[j] = 0.0; // 加速度目標値
        positionCmd[j] = positionCmd[j - 1] + speedCmd[j] / dataNum ;
      }
      else if (clock > 0.1) {
        speedCmd[j] = clock - 0.1; // 速度目標値
        accelCmd[j] = (speedCmd[j] - speedCmd[j - 1]) * dataNum; // 加速度目標値
        positionCmd[j] = positionCmd[j - 1] + speedCmd[j] / dataNum ;
      }
      else {
        accelCmd[j] =  0.0; // 加速度目標値
        speedCmd[j] = 0.0; // 速度目標値
        positionCmd[j] = 0.0; // 位置目標値
      }
      correctData[j] = speedCmd[j]; // 所望速度（ランプ波形）;
              
      control.setSpeed(speedCmd[j]);
      control.execute();
      pidData[0][j] = control.getForce(); // 力操作量
      pidData[1][j] = control.getAcceleration(); // 応答加速度
      pidData[2][j] = control.getSpeed(); // 応答速度
      pidData[3][j] = control.getPosition(); // 応答位置
    }

    // DJ.print(" 順序をランダムにシャッフルしたインデックスのリスト");
    //#    ArrayList<Integer> indexList = DJ.permutationRandom(dataNum);
    //#    DJ.print("inputData", inputData);
    
    DJ._print("・ニューラルネットの各層の初期化");
    middleLayer0 = new ControlLayer(inNodeNum, midNodeNum);
    if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
      middleLayer0.initialize(eta, initialCoef, Activator.RELU, Layer.ADA_GRAD);
    else // ２層：中間１層を飛ばして中間０層と出力層でデータを受け渡しする
      middleLayer0.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
    double[] mid0X = middleLayer0.getX();
    double[] mid0Y = middleLayer0.getY();
    // double[] mid0C = middleLayer0.getC();
    // double[] mid0E = middleLayer0.getE();
    // double[] mid0dEdX = middleLayer0.getdEdX();
    //# double[] mid0dEdY = middleLayer0.getdEdY(); // 参照を代入（注１）

    double[] mid1Y;
    double[] mid1dEdX;
    middleLayer1 = new ControlLayer(midNodeNum, midNodeNum);
    if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
      middleLayer1.initialize(eta, initialCoef, Activator.RELU, Layer.ADA_GRAD);
    else // ２層：中間１層を飛ばして中間０層と出力層でデータを受け渡しする
      // middleLayer1.initialize(eta, initialCoef, Activator.RELU, Layer.SGD);
      middleLayer1.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
    // double[] mid1X = middleLayer1.getX(); // 参照を代入（注３）
    mid1Y = middleLayer1.getY();
    // double[] mid1C = middleLayer1.getC();
    // double[] mid1E = middleLayer1.getE();
    mid1dEdX = middleLayer1.getdEdX();
    //# double[] mid1dEdY = middleLayer1.getdEdY(); // 参照を代入（注２）

    outputLayer = new ControlLayer(midNodeNum, outNodeNum);
    // outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.SGD);
    outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.ADA_GRAD);
    //# double[] outX = outputLayer.getX(); // 参照を代入（注４）
    double[] outY = outputLayer.getY();
    double[] outC = outputLayer.getC();
    /**/    double[] outE = outputLayer.getE();
    double[] outdEdX = outputLayer.getdEdX();
    // double[] outdEdY = outputLayer.getdEdY();

    double[] outS = new double[outC.length]; // 正解（所望速度）

    if (layerNum == THREE_LAYER) { // ３層：中間１層を使用する
      middleLayer0.setdEdY(mid1dEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
      middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
      middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
      outputLayer.setX(mid1Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    }
    else { // ２層：中間１層を飛ばして中間０層と出力層でデータを受け渡しする
      middleLayer0.setdEdY(outdEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
      // middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
      // middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
      outputLayer.setX(mid0Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    }
    
    // DJ._print("PatternViewerへの参照を得る");
    patternViewer = patternViewerLauncher.getPatternViewer();
    // DJ._print("GraphViewerへの参照を得る");
    graphViewer = graphViewerLauncher.getGraphViewer();
    if (graphViewer != null) graphViewer.shiftGraphAxis(graphShift); // グラフの縦軸をシフトする
    
    // DJ._print("・学習誤差の計測用データ");
    double squareError; // 二乗誤差
    double meanError; // 二乗誤差データの平均
    
    DJ._print(" ##### ニューラルネットの学習開始 #####");
    for (int i = 0; i <= epoch; i++) {
      startTime = System.nanoTime(); // 実行開始時刻
      intervalFlag = (i % interval == interval - 1) | 
              (i == epoch); //経過表示フラグ

      // DJ._print("・学習用データのインデックスをシャッフル");
      //# Collections.shuffle(indexList);
      
      // DJ.print("・学習誤差と試行誤差の初期化");
      squareError = 0.0; // 二乗誤差
      meanError = 0.0; // 二乗誤差データの平均
    
      //# int index = indexList.get(i); // ランダム・インデックス
      
      // 制御をリセット
      control.setParameter(0, 0, inertia, friction);
      control.stopSpeedControl(); // 速度制御を止める（無効化する）
      control.stopFeedback(); // フィードバック処理を止める（無効化する）
      control.resetControl();
      
      // 全データを学習
      for (int j = 0; j < dataNum; j++) {
        // DJ._print("・エポックの実行ループ開始 ----------------------------");
        // DJ.print(" エポック:i=" + i + ", 学習データ:j=" + j);
        
        // DJ._print(" ##### 前向き伝播の呼び出し #####");
        // DJ._print("・入力データを取り出し");
        //#        int index = indexList.get(j); // ランダム・インデックス
        int index = j; // 非ランダム・インデックス
        // DJ._print(" 学習用データID：learnIndex=", learnIndex);
        //# mid0X[0] = inputData[index]; // 入力データをランダムに取出し
        mid0X[0] = accelCmd[index] * 0.05; // 加速度目標値
        mid0X[1] = correctData[index]; // 正解データ（所望速度を応答速度として入力）
//        mid0X[2] = positionCmd[index]; // 位置目標値
        middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
        if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
          middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
        
        // DJ._print(" 推定された力操作量を制御対象（負荷）に与える");
        double force = outY[0] / Control.DAC_FACTOR;
        control.setForce(force); // 力の操作量
        control.execute(); // 制御対象の実行
        outS[0] = control.getSpeed(); // 応答速度;
        
        // DJ._print(" ##### 後向き伝播の呼び出し #####");
        outC[0] = correctData[index]; // 正解データを取出し
        //# outputLayer.backward(outC); // 出力層の逆伝播処理を呼び出し
        outputLayer.backward(outC, outS); // 出力層の逆伝播処理を呼び出し（逆関数の学習時）
        if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
          middleLayer1.backward(); // 中間層１の逆伝播処理を呼び出し
        middleLayer0.backward(); // 中間層０の逆伝播処理を呼び出し
        
        // DJ._print("・バッチ処理");
        if ((j % batchNum) == 0 ) {
          // DJ._print(" ##### 更新の呼び出し #####");
          middleLayer0.update(); // 中間層０の更新処理を呼び出し
          if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
            middleLayer1.update(); // 中間層１の更新処理を呼び出し
          outputLayer.update(); // 出力層の更新処理を呼び出し
        }

        // DJ._print(" ##### 学習時の誤差の保存 #####");
        if (intervalFlag) { // 実行時間に影響を与える
          // 二乗誤差の算出
          squareError = outE[0] * outE[0] / 2.0;
          meanError = meanError + squareError;
        
          // パターン表示用データの保存
          patternData0[1][j] = force * 0.001; // 推定された力操作量
//          patternData0[2][j] = outS[0]; // 応答速度
//          patternData0[3][j] = correctData[j]; // 正解（所望速度）
          patternData0[2][j] = -10.0; //
          patternData0[3][j] = -10.0; //
          
//          patternData0[4][j] = control.getAcceleration() * 100.0; // 応答加速度
//          patternData0[5][j] = -10.0; // 
//          patternData0[6][j] = control.getPosition(); // 応答位置加速度
          
          patternData0[4][j] = pidData[0][j] / 1000.0; // 力操作量（PID制御）
          patternData0[5][j] = pidData[1][j] * 100.0; // 応答加速度（PID制御）
          patternData0[6][j] = pidData[2][j]; // 応答速度（PID制御）

          patternData0[7][j] = accelCmd[j] * 0.1; // 加速度目標値 
          patternData0[8][j] = speedCmd[j]; // 速度目標値
//          patternData0[9][j] = positionCmd[j]; // 位置目標値
          patternData0[9][j] = -10.0; //
        }
        // DJ._print(" End of one data ---------------");
      } // 全データを学習
      // DJ._print(" End of one epoch ---------------");
      
      // DJ._print("　エポック毎の誤差の平均値を保存する");
      double aveError = meanError / dataNum ;
      
      // 実行時間の累積
      endTime = System.nanoTime(); // 休止時刻
      double lapTime_ = (endTime - startTime) / 1000000.0;
      if (lapTime_ > 0.0) lapTime = lapTime_; // オーバーフロー対策
      totalTime = totalTime + lapTime; // 経過時間を追加

       // 経過表示インターバル
      if (intervalFlag) { // 実行時間に影響を与える
        // DJ._print(" ##### １エポックの実行結果 #####");
        DJ.print_(" i=" + i);
        DJ.print_(", squareError = ", squareError);
        DJ.print(", aveError = ", aveError);

        // グラフ表示用データを代入
        graphData[0] = squareError;
        graphData[1] = aveError;
        updateViewer(i); // ビューワの表示を更新
        
        DJ.print_(", lapTime = ", lapTime); DJ.print("[msec]");
        
        // スレッドの休止（実行速度の調整および経過表示のため）
        synchronized(this) {
          try {
            // DJ.print("Ｅnter to wait(sleepTime)");
            wait(SLEEP_TIME); // タイムアウト付きで待機状態
            // DJ.print("Resume from wait(sleepTime)");
            if (pauseFlag) wait(); // 休止状態
          }
          catch (InterruptedException e) {
             DJ.print("***** ERROR ***** " + getClass().getName() + "\n"
                 + " Exception occur in wait(sleepTime):" + e.toString());
          }
        } // synchronized()
      } // interval
    
      // 実行処理を強制的に終了させる
      if (abortFlag) {
        DJ._print("##### Abort action requested");
        epoch = i; // 現在のエポック回数iをepochに代入し、実行を強制的に終了させる
      }
      
      // DJ._print(" End of one epoch ---------------------------------------");
    }
    DJ._print(" End of all epoch --------------------------------------------");
    
    DJ._print("・エポック実行回数");
    DJ.print(" Last epoch = ", epoch);
    DJ.print_("・最終誤差: ");
    for (int k = 0; k < graphData.length; k++) {
      DJ.print_("  " + dataName[k] + "=" + graphData[k]);
    }
    DJ.print("");
    
    DJ._print("・検証：学習した逆関数で制御対象を駆動し、応答速度を比較する");
    control.setParameter(0, 0, inertia, friction); // パラメータ
    control.stopSpeedControl(); // 速度制御を止める（無効化する）
    control.resetControl();
    
    for (int j = 0; j < dataNum; j++) {
      //# mid0X[0] = inputData[j]; // 入力データを取出し
      mid0X[0] = accelCmd[j] * 0.05; // 加速度目標値
      //# mid0X[1] = control.getSpeed(); // 応答速度をフィードバックして入力
      mid0X[1] = correctData[j]; // 所望速度
//      mid0X[2] = positionCmd[j]; // 位置目標値
      middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
      if (layerNum == THREE_LAYER) // ３層：中間１層を使用する
        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
      outputLayer.forward(); // 出力層の順伝播処理を呼び出し
      
      double force = outY[0] / Control.DAC_FACTOR; // ＮＮで算出された力操作量
      control.setForce(force); // 力の操作量 frc_mnp
      control.execute(); // 制御対象の実行

      patternData0[1][j] = force * 0.001; // ＮＮで算出された力操作量
      
      patternData1[1][j] = patternData0[1][j]; // ＮＮで算出された力操作量
      patternData1[2][j] = -10.0; // 
      patternData1[3][j] = -10.0; // 
      
      patternData1[4][j] = control.getAcceleration() * 100.0; // 応答加速度
      patternData1[5][j] = control.getSpeed(); // 応答速度
      patternData1[6][j] = control.getPosition(); // 応答位置速度
      
      patternData1[7][j] = accelCmd[j] * 0.1; // 加速度目標値
      patternData1[8][j] = speedCmd[j]; // 速度目標値（所望速度）
      patternData1[9][j] = positionCmd[j]; // 位置目標値
    }
    
    updatePattern();
    
    DJ._print("・制御系と制御対象のパラメータ");
    DJ.print_(" 比例ゲイン:speedGain=",speedGain);
    DJ.print_(", 積分ゲイン:integralGain=",integralGain);
    DJ.print_(", 慣性質量:inertia=",inertia);
    DJ.print(", 粘性摩擦:friction=",friction);
    
    DJ._print("・収集データ数: ", dataNum);
    DJ.print("力操作量の推定値: f~", patternData1[1]);
    DJ.print("応答速度: v", patternData1[5]);
    DJ.print("応答速度（PID制御）: v", patternData0[6]);
    
    // 処理時間の算出
    DJ._print_("・総実行時間：" + (totalTime / 1000.0) + " [sec]");
    double aveTime = totalTime / epoch;
    DJ.print(", 平均実行時間：" + aveTime + " [msec/epoch]");
    finishTime = System.currentTimeMillis(); // タスク開始時刻
    DJ.print_("・タスク処理時間：" + 
            ((finishTime - beginTime) / 1000.0) + " [sec]");
    DJ.print(", タスク終了日時：", TimeStamp.getTimeFormated());
    
    DJ._print("##### PathTracker 終了 #####");
  } // pathTracker()

} // PathTracker Class

// End of file
