/*
 *  Title: DaiJa_V4 (Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package task.control;

import active.Activator;
import data.Control;
import layer.ControlLayer;
import layer.Layer;
import task.Task;
import util.DJ;
import util.TimeStamp;
import view.PatternViewer;

/**
 * <p> 表　題: Class: TargetController</p>
 * <p> 説　明: TargetEmulatorを用いて制御対象の逆関数NNinvを学習し、
 *              fwdNNinvで力操作量f^を求めて制御対象を制御する。</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.08.23</p>
 */
public class TargetController extends Task {

  // 学習・試行制御
  int epoch = 20000; // 1000; // エポックの回数
  int batchNum = 1; // 10; // バッチ数
  int interval = 200; //1; // 経過表示間隔

  // 学習パラメータ
  double initialCoef = 0.8; //0.2; // 重みとバイアスの初期値係数
  double eta = 0.3; //0.2; // 学習係数
  int dataNum = 1000; // エポック毎の学習データ数1[s]/1[ms] = 1000 [個]

  // 制御系とパラメータ
  Control control; // 制御系と制御対象
  int speedGain = 1280; //2048 // 比例ゲイン（速度ループゲイン）
  int integralGain = 0; // 積分ゲイン
  double inertia = 1.0; // 慣性質量［kgm^2/sec^2］
  double friction = 2.0; // 粘性摩擦[kg/sec]

  // 各層のノード数
  int inNodeNum = 2; // 入力層のノード数
  int midNodeNum = 7; //8; // 中間層のノード数
  int outNodeNum = 1; // 出力層のノード数

  // レイヤー構成
  private ControlLayer middleLayer0; // 中間層０
  //#  private ControlLayer middleLayer1; // 中間層１
  private ControlLayer outputLayer; // 出力層

  // 誤差変数
  double squareError; // 二乗誤差
  double meanError; // 二乗誤差データの平均
  double aveError; // エポック毎の誤差の平均値
  
  /**
   * Taskクラスで新しく作られたスレッドから呼び出される
   */
  @Override
  public void runTask() {
    DJ._print("・タスク開始日時：", TimeStamp.getTimeFormated());
    beginTime = System.currentTimeMillis(); // タスク開始時刻

    // パターン・ビューワ・ランチャを得る
    patternViewerFlag = true; // パターン・ビューワを表示する
    patternData0 = new double[10][dataNum]; // ID,出力(力),速度,正解(速度所望値)
    patternData1 = new double[10][dataNum]; // ID,力操作量,加速度,所望速度
    patternViewerLauncher = DJ.pattern( // 判定パターンとデータ
        PatternViewer.PATTERN_TUNER, patternData0, "TargetController:Learning",
        PatternViewer.PATTERN_TUNER, patternData1, "TargetController:Trial");
    
    // グラフ・ビューワ・ランチャを得る
    graphShift = 8; // グラフの縦軸を移動
    graphData = new double[2];
    dataName = new String[2]; // データ名
    dataName[0] = "squareError"; // 学習時の二乗誤差
    dataName[1] = "aveError"; // 学習時の平均二乗誤差
    graphViewerLauncher = DJ.graph(epoch, interval, dataName, graphData);
    
    targetController(); // タスク本体の呼び出し
  }

  /**
   * 制御対象の逆関数を学習し、力操作量を求めて制御対象を制御する
   */
  public void targetController() {
    DJ._print("DaiJa_V3, TargetController.targetController() ================");

    DJ._print("・パラメータ");
    DJ.print_(" 重みとバイアスの初期値係数:initialCoef=", initialCoef);
    DJ.print_(", 学習係数:eta=", eta);
    DJ.print(", エポックの回数:epoch=", epoch);

    DJ.print_(" バッチ数:batchNum=", batchNum);
    //#    DJ.print_(", ドロップ・アウト率:dropOutRate=",dropOutRate);
    DJ.print_(", 経過表示間隔:interval=", interval);
    DJ.print(", 学習データ数：dataNum=", dataNum);

    DJ.print("・制御系のパラメータ");
    DJ.print_(" 比例ゲイン:speedGain=", speedGain);
    DJ.print_(", 積分ゲイン:integralGain=", integralGain);
    DJ.print_(", 慣性質量:inertia=", inertia);
    DJ.print(", 粘性摩擦:friction=", friction);

    DJ._print("・各層のノード数");
    DJ.print_(" 入力層ノード数:inNodeNum=", inNodeNum);
    DJ.print_(", 中間層ノード数:midNodeNum=", midNodeNum);
    DJ.print(", 出力層ノード数:outNodeNum=", outNodeNum);
    
    DJ._print("・制御系と制御対象を設定");
    control = new Control();
    //#    control.setParameter(speedGain, integralGain, 1.0, 0.0); // パラメータ
    control.setParameter(0, 0, inertia, friction); // パラメータ
    control.stopSpeedControl(); // 速度制御を止める（無効化する）
    control.resetControl(); // 制御系をリセット
    
    DJ._print("・制御対象から入力（力操作量）と出力（応答速度）を収集する");
    DJ._print("　入力データと教師データを作成");
    double[] inputData = new double[dataNum]; // 入力データ
    double[] correctData = new double[dataNum]; // 正解データ
    //# double[] forceData = new double[dataNum]; // 力操作量データ
    //# double[] acceleration = new double[dataNum]; // 加速度データ
     
    double inc = 1.0 / (double)dataNum;
    double clock = -inc;
    for (int j = 0; j < dataNum; j++) {
      clock = clock + inc;
      inputData[j] = clock; // 入力データ（クロック）
//**/      correctData[j] = control.speed_step(clock); // 所望速度（理論ステップ応答）
/**/      correctData[j] = control.getS_Curve(clock); // 所望速度（Ｓ字カーブ応答）
//**/      correctData[j] = control.spd_ramp(clock) * 0.6; // 所望速度（理論ランプ応答）

      patternData0[0][j] = clock; // パターン表示（0頁目）の時間軸
      patternData1[0][j] = clock; // パターン表示（1頁目）の時間軸
    }
    
//    DJ.print("inputData", inputData);
//    DJ.print("correctData", correctData);

    DJ._print("・制御対象（仮想制御対象）を設定");
    TargetEmulator emulator = new TargetEmulator();
    emulator.runTask();
    emulator.initEmulator(); // 仮想制御対象を初期化
    double force; // 力操作量
    double acceleration; // 推定加速度
    double speed; // 推定速度
    double position; // 推定位置
    
    // DJ.print(" 順序をランダムにシャッフルしたインデックスのリスト");
    //#    ArrayList<Integer> indexList = DJ.permutationRandom(dataNum);
    //#    DJ.print("inputData", inputData);
    
    DJ._print("・ニューラルネットの各層の初期化");
    middleLayer0 = new ControlLayer(inNodeNum, midNodeNum);
//    middleLayer0.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
    middleLayer0.initialize(eta, initialCoef, Activator.SIGMOID, Layer.ADA_GRAD);
    double[] mid0X = middleLayer0.getX();
    double[] mid0Y = middleLayer0.getY();
    // double[] mid0C = middleLayer0.getC();
    // double[] mid0E = middleLayer0.getE();
    // double[] mid0dEdX = middleLayer0.getdEdX();
    //# double[] mid0dEdY = middleLayer0.getdEdY(); // 参照を代入（注１）

    // middleLayer1 = new ControlLayer(midNodeNum, midNodeNum);
    //// middleLayer1.initialize(eta, initialCoef, Activator.SIGMOID, Layer.SGD);
    //  middleLayer1.initialize(eta, initialCoef, Activator.SIGMOID, Layer.ADA_GRAD);
    //# double[] mid1X = middleLayer1.getX(); // 参照を代入（注３）
    // double[] mid1Y = middleLayer1.getY();
    // double[] mid1C = middleLayer1.getC();
    // double[] mid1E = middleLayer1.getE();
    // double[] mid1dEdX = middleLayer1.getdEdX();
    //# double[] mid1dEdY = middleLayer1.getdEdY(); // 参照を代入（注２）

    outputLayer = new ControlLayer(midNodeNum, outNodeNum);
//    outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.SGD);
    outputLayer.initialize(eta, initialCoef, Activator.IDENTITY, Layer.ADA_GRAD);
    //#    double[] outX = outputLayer.getX(); // 参照を代入（注４）
    double[] outY = outputLayer.getY();
    double[] outC = outputLayer.getC();
    double[] outE = outputLayer.getE();
    double[] outdEdX = outputLayer.getdEdX();
    //#    double[] outdEdY = outputLayer.getdEdY();

    double[] outS = new double[outC.length]; // 正解（所望速度）

    //#    middleLayer0.setdEdY(mid1dEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
    //#    middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
    //#    middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
    //#    outputLayer.setX(mid1Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    middleLayer0.setdEdY(outdEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
    //#    middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
    //#    middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
    outputLayer.setX(mid0Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定

    // DJ._print("PatternViewerへの参照を得る");
    patternViewer = patternViewerLauncher.getPatternViewer();
    // DJ._print("GraphViewerへの参照を得る");
    graphViewer = graphViewerLauncher.getGraphViewer();
    if (graphViewer != null) {
      graphViewer.shiftGraphAxis(graphShift); // グラフの縦軸をシフトする
    }

    DJ._print(" ##### ニューラルネットの学習開始 #####");
    for (int i = 0; i <= epoch; i++) {
      startTime = System.nanoTime(); // 実行開始時刻
      // DJ.print("・経過表示フラグ");
      intervalFlag = (i % interval == interval - 1) | (i == epoch);

      // DJ._print("・学習用データのインデックスをシャッフル");
      //#      Collections.shuffle(indexList);

      // DJ.print("・学習誤差と試行誤差の初期化");
      squareError = 0.0; // 二乗誤差
      meanError = 0.0; // 二乗誤差データの平均
    
      // DJ._print("・制御系と制御対象の実行準備");
      control.setParameter(0, 0, inertia, friction); // パラメータ
      control.stopSpeedControl(); // 速度制御を止める（無効化する）
      control.resetControl();
      
      // DJ._print("・制御対象（仮想制御対象）の実行準備");
      emulator.initEmulator(); // 仮想制御対象を初期化

      // 全データを学習
      for (int j = 0; j < dataNum; j++) {
        // DJ._print("・エポックの実行ループ開始 ----------------------------");
        // DJ.print(" エポック:i=" + i + ", 学習データ:j=" + j);

        // DJ._print(" ##### 前向き伝播の呼び出し #####");
        // DJ._print("・入力データを取り出し");
        //#        int index = indexList.get(j); // ランダム・インデックス
        int index = j; // 非ランダム・インデックス
        // DJ._print(" 学習用データID：learnIndex=", learnIndex);
        mid0X[0] = inputData[index]; // クロックを入力
        mid0X[1] = correctData[index]; // 正解データ（所望速度を応答速度として入力）
        middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
        //#        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
        
        // DJ._print(" 比較のため、制御対象も実行する");
//        control.setForce(force); // 力の操作量 frc_mnp
        control.setForce(outY[0] / Control.DAC_FACTOR); // 力の操作量 frc_mnp
        control.execute(); // 制御対象の実行        
        
        
        // DJ._print(" 推定された力操作量を制御対象（負荷）に与える");
        force = outY[0] / Control.DAC_FACTOR; // 力操作量
        acceleration = emulator.setForce(force); // 力の操作量 -> 応答加速度
        speed = emulator.getSpeed(); // 応答速度
        outS[0] = speed * 10.0; // 応答速度 ※スケーリングが必要
        position = emulator.getPosition();
        
        // DJ._print(" ##### 後向き伝播の呼び出し #####");
        outC[0] = correctData[index] * 1.0; // 正解データを取出し
        //#        outputLayer.backward(outC); // 出力層の逆伝播処理を呼び出し
        outputLayer.backward(outC, outS); // 出力層の逆伝播処理を呼び出し
        //#        middleLayer1.backward(); // 中間１層の逆伝播処理を呼び出し
        middleLayer0.backward(); // 中間０層の逆伝播処理を呼び出し

        // DJ._print("・バッチ処理");
        if ((j % batchNum) == 0) {
          // DJ._print(" ##### 更新の呼び出し #####");
          middleLayer0.update(); // 中間０層の更新処理を呼び出し
          //#          middleLayer1.update(); // 中間層１の更新処理を呼び出し
          outputLayer.update(); // 出力層の更新処理を呼び出し
        }

        // DJ._print(" ##### 学習時の誤差の保存 #####");
        if (intervalFlag) { // 実行時間に影響を与える
          // 二乗誤差の算出
          squareError = outE[0] * outE[0] / 2.0;
          meanError = meanError + squareError;

          // パターン表示用データの保存
          // TargetEmulator（仮想制御対象）
          patternData0[1][j] = outY[0] * Control.DAC_FACTOR * 20.0; // 推定された力操作量
          patternData0[2][j] = outS[0]; // * 0.96; // 推定された応答速度（赤）
          patternData0[3][j] = correctData[j]; // * 0.96; // 所望速度（黒）
          
          patternData0[4][j] = -10.0; // ダミー
          patternData0[5][j] = speed * 10.0 + 0.01; // 推定された応答速度
          patternData0[6][j] = position * 10.0; // 推定された応答位置

          patternData0[7][j] = -10.0; // ダミー
          patternData0[8][j] = control.getSpeed(); // 応答速度
          patternData0[9][j] = control.getPosition(); // 応答位置

          // 制御対象
          patternData1[1][j] = patternData0[1][j]; // 推定された力操作量
          patternData1[2][j] = -10.0; // ダミー
          patternData1[3][j] = -10.0; // ダミー
          
          patternData1[4][j] = patternData0[3][j]; // 所望速度
          patternData1[5][j] = patternData0[2][j]; // 推定された応答速度
          patternData1[6][j] = patternData0[4][j]; // 推定された応答位置
          
          patternData1[7][j] = -10.0; // ダミー
          patternData1[8][j] = control.getSpeed(); // 応答速度
          patternData1[9][j] = control.getPosition(); // 応答位置
        }

        // DJ._print(" End of one data ---------------");
      } // 全データを学習
      // DJ._print(" End of one epoch ---------------");

      // DJ._print("　エポック毎の誤差の平均値を保存する");
      aveError = meanError / dataNum;

      // 実行時間の累積
      endTime = System.nanoTime(); // 休止時刻
      double lapTime_ = (endTime - startTime) / 1000000.0;
      if (lapTime_ > 0.0) lapTime = lapTime_; // オーバーフロー対策
      totalTime = totalTime + lapTime; // 経過時間を追加

      // 経過表示インターバル
      if (intervalFlag) { // 実行時間に影響を与える
        // DJ._print(" ##### １エポックの実行結果 #####");
        DJ.print_(" i=" + i);
        DJ.print_(", squareError = ", squareError);
        DJ.print(", aveError = ", aveError);

        graphData[0] = squareError;
        graphData[1] = aveError;
        updateViewer(i); // ビューワの表示を更新
        
        DJ.print_(", lapTime = ", lapTime);
        DJ.print("[msec]");

        // スレッドの休止（実行速度の調整および経過表示のため）
        synchronized (this) {
          try {
            // DJ.print("Ｅnter to wait(sleepTime)");
            wait(SLEEP_TIME); // タイムアウト付きで待機状態
            // DJ.print("Resume from wait(sleepTime)");
            if (pauseFlag) wait(); // 休止状態
          }
          catch (InterruptedException e) {
            DJ.print("***** ERROR ***** " + getClass().getName() + "\n"
                + " Exception occur in wait(sleepTime):" + e.toString());
          }
        } // synchronized()
      } // interval

      // 実行処理を強制的に終了させる
      if (abortFlag) {
        DJ._print("##### Abort action requested");
        epoch = i; // 現在のエポック回数iをepochに代入し、実行を強制的に終了させる
      }

      // DJ._print(" End of one epoch ---------------------------------------");
    }
    DJ._print(" End of all epoch --------------------------------------------");

    DJ._print("・エポック実行回数");
    DJ.print(" Last epoch = ", epoch);
    DJ.print_("・最終誤差: ");
    for (int k = 0; k < graphData.length; k++) {
      DJ.print_("  " + dataName[k] + "=" + graphData[k]);
    }
    DJ.print("");

    DJ._print("・制御系と制御対象のパラメータ");
    DJ.print_(" 比例ゲイン:speedGain=", speedGain);
    DJ.print_(", 積分ゲイン:integralGain=", integralGain);
    DJ.print_(", 慣性質量:inertia=", inertia);
    DJ.print(", 粘性摩擦:friction=", friction);

    DJ._print("・収集データ数: ", dataNum);
    DJ.print("力操作量の推定値: f~", patternData0[1]);
    DJ.print("応答速度: v", patternData0[2]);

    
    DJ._print("・学習した逆関数で制御対象を駆動し、応答を得る");
    // DJ._print("・制御系を停止し、制御対象をリセットする");
    control.setParameter(0, 0, inertia, friction); // パラメータ
    control.stopSpeedControl(); // 速度制御を止める（無効化する）
    control.resetControl();
    
    for (int j = 0; j < dataNum; j++) {
      mid0X[0] = inputData[j]; // 入力データを取出し
      mid0X[1] = correctData[j]; // 正解データ（所望速度を応答速度として入力）
      middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
      //#        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
      outputLayer.forward(); // 出力層の順伝播処理を呼び出し
      
      control.setForce(outY[0] / Control.DAC_FACTOR); // 力の操作量 frc_mnp
      control.execute(); // 制御対象の実行

      patternData1[1][j] = outY[0] * Control.DAC_FACTOR * 20.0; // 力操作量
      patternData1[2][j] = -10.0; // ダミー
      patternData1[3][j] = -10.0; // ダミー
      patternData1[4][j] = -10.0; // ダミー
      patternData1[5][j] = correctData[j]; // 所望速度 
      patternData1[6][j] = -10.0; // ダミー
      patternData1[7][j] = -10.0; // ダミー
      patternData1[8][j] = control.getSpeed() * 1.0; // 速度(推定値)
      patternData1[9][j] = control.getPosition() * 1.0; // 応答位置
    }
    updatePattern(); // パターン・ビューワの表示を更新
    
    // 処理時間の算出
    DJ._print_("・総実行時間：" + (totalTime / 1000.0) + " [sec]");
    double aveTime = totalTime / epoch;
    DJ.print(", 平均実行時間：" + aveTime + " [msec/epoch]");
    finishTime = System.currentTimeMillis(); // タスク開始時刻
    DJ.print_("・タスク処理時間："
            + ((finishTime - beginTime) / 1000.0) + " [sec]");
    DJ.print(", タスク終了日時：", TimeStamp.getTimeFormated());

    DJ._print("##### TargetController 終了 #####");
  } // targetController()

} // TargetController Class

// End of file
