/*
 *  Title: DaiJa_V4 (Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.4.4
 *  Copyright: 2020, 2021
 */
package task;

import active.Activator;
import data.IrisData;
import java.util.ArrayList;
import java.util.Collections;
import layer.Layer;
import layer.LinearLayer;
import util.DJ;
import util.TimeStamp;
import view.PatternViewer;

/**
 * <p> 表　題: Class: Classifier</p>
 * <p> 説　明: アイリスの品種の分類</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2019, 2021</p>
 * <p> 作成日: 2020.4.4</p>
 */
public class Classifier extends Task {
  private boolean trialFlag = false; // 試行フラグ　true:試行時　false:学習時

  // 学習・試行制御
  int epoch = 300; // 1000; // エポックの回数
  int batchNum = 10; //20; // 8; // バッチ数
  int interval = 5; //10; // 経過表示間隔
    
  // 学習パラメータ
  double initialCoef = 0.01; // 重みとバイアスの初期値係数
  double eta = 0.05; // 学習係数
  double dropOutRate = 0.0; //0.3; //0.5 // ドロップ・アウト率（0.0～0.5）

  // 各層のノード数
  int inNodeNum = 4; // 入力層のノード数（データの種類）
  int midNodeNum = 50; // 25; // 中間層のノード数
  int outNodeNum = 3; // 出力層（分類数）
    
  // レイヤー構成
  private LinearLayer middleLayer0; // 中間層０
  private LinearLayer middleLayer1; // 中間層１
  private LinearLayer outputLayer; // 出力層

  /**
   * Taskクラスで新しく作られたスレッドから呼び出される
   */
  @Override
  public void runTask() {
    DJ._print("・タスク開始日時：", TimeStamp.getTimeFormated());
    beginTime = System.currentTimeMillis(); // タスク開始時刻
    
    // パターン・ビューワ・ランチャを得る
    patternViewerFlag = true; // パターン・ビューワを表示する
    patternData0 = new double[4][IrisData.DATA_NUM]; // パターン・データ２個、出力と正解
    patternData1 = new double[4][IrisData.DATA_NUM]; // パターン・データ２個、出力と正解
    patternViewerLauncher = DJ.pattern(
        PatternViewer.PATTERN_GROUP, patternData0, "Classifier0",  // 分類
        PatternViewer.PATTERN_GROUP, patternData1, "Classifier1"); // 分類
    
    // グラフ・ビューワ・ランチャを得る
    graphShift = 1; // グラフの縦軸を移動
    graphData = new double[4];
    dataName = new String[4]; // データ名
    dataName[0] = "LearnError"; // 学習時の二乗誤差
    dataName[1] = "LearnEntropy"; // 学習時の交差エントロピー誤差
    dataName[2] = "TrialError"; // 試行時の二乗誤差
    dataName[3] = "TrialEntropy"; // 試行時の交差エントロピー誤差
    graphViewerLauncher = DJ.graph(epoch, interval, dataName, graphData);
    
    // ニューラルネットワークによる分類
    classifier();
  }
  
  /** 
   * ニューラルネットワークによる分類
   */
  public void classifier() {
    DJ._print("Classifier.classifier() ==============================");
    
    DJ.print("・パラメータ");
    DJ.print_(" initialCoef=", initialCoef);
    DJ.print_(", eta=", eta);
    DJ.print(" epoch=", epoch);
    
    DJ.print_(", batchNum=", batchNum);
    DJ.print_(", dropOutRate=", dropOutRate);
    DJ.print(", interval=", interval);
    
    DJ._print("　各層のノード数");
    DJ.print_(" inNodeNum=", inNodeNum);
    DJ.print_(", midNodeNum=", midNodeNum);
    DJ.print(", outNodeNum=", outNodeNum);
    
    DJ._print("・入力データと教師データを作成");
    int partNum = IrisData.PART_NUM; // 種類ごとのデータ数
    int dataNum = IrisData.DATA_NUM; // 全データ数
    DJ.print_(" partNum=", partNum);
    DJ.print(", dataNum=", dataNum);
    
    double[][] learnData = IrisData.getLearnData(); // 学習用データ
    double[][] trialData = IrisData.getTrialData(); // 試行用データ
    double[][] groupData = IrisData.getGroupData(); // 種別データ
    DJ.print("　学習用データ：learnData=", learnData);
    DJ.print("　試行用データ：trialData=", trialData);
    DJ.print("　種別データ：groupData=", groupData);
    
    // DJ.print("順序をランダムにシャッフルしたインデックスのリスト");
    ArrayList<Integer> randomIndexList = DJ.permutationRandom(dataNum);
    
    // パターン・ビューワに渡すデータ
    patternData0 = new double[4][dataNum]; // パターン・データ２個、出力と正解
    patternData1 = new double[4][dataNum]; // パターン・データ２個、出力と正解

    DJ._print("・ニューラルネットの各層の初期化");
    middleLayer0 = new LinearLayer(inNodeNum, midNodeNum);
    middleLayer0.initialize(eta, initialCoef, Activator.RELU, Layer.ADA_GRAD);
    double[] mid0X = middleLayer0.getX();
    double[] mid0Y = middleLayer0.getY();
//    double[] mid0C = middleLayer0.getC();
//    double[] mid0E = middleLayer0.getE();
//    double[] mid0dEdX = middleLayer0.getdEdX();
//#    double[] mid0dEdY = middleLayer0.getdEdY(); // 参照を代入（注１）
    
    middleLayer1 = new LinearLayer(midNodeNum, midNodeNum);
    middleLayer1.initialize(eta, initialCoef, Activator.RELU, Layer.ADA_GRAD);
//#    double[] mid1X = middleLayer1.getX(); // 参照を代入（注３）
    double[] mid1Y = middleLayer1.getY();
//    double[] mid1C = middleLayer1.getC();
//    double[] mid1E = middleLayer1.getE();
    double[] mid1dEdX = middleLayer1.getdEdX();
//#    double[] mid1dEdY = middleLayer1.getdEdY(); // 参照を代入（注２）
    
    outputLayer = new LinearLayer(midNodeNum, outNodeNum);
    outputLayer.initialize(eta, initialCoef, Activator.SOFTMAX, Layer.ADA_GRAD);
//#    double[] outX = outputLayer.getX(); // 参照を代入（注４）
    double[] outY = outputLayer.getY();
    double[] outC = outputLayer.getC();
//    double[] outE = outputLayer.getE();
    double[] outdEdX = outputLayer.getdEdX();
//    double[] outdEdY = outputLayer.getdEdY();
    
    middleLayer0.setdEdY(mid1dEdX); // （注１）中間０層のdEdYへ、中間１層のdEdXの参照を設定
    middleLayer1.setdEdY(outdEdX); // （注２）中間１層のdEdYへ、出力層のdEdXの参照を設定
    middleLayer1.setX(mid0Y); // （注３）中間１層の入力mid1Xへ、中間０層の出力mid0Yの参照を設定
    outputLayer.setX(mid1Y); // （注４）出力層の入力outXへ、中間層の出力mid1Yの参照を設定
    
    // DJ._print("PatternViewerへの参照を得る");
    patternViewer = patternViewerLauncher.getPatternViewer();
    // DJ._print("GraphViewerへの参照を得る");
    graphViewer = graphViewerLauncher.getGraphViewer();
    if (graphViewer != null) graphViewer.shiftGraphAxis(graphShift); // グラフの縦軸をシフトする
    
    // DJ._print("学習データによる誤差の初期化");
    double errorSum = 0.0; //  二乗誤差
    double entropySum = 0.0; // 交差エントロピー誤差
    
    // DJ._print("試行データによる誤差の初期化");
    double trialErrorSum = 0.0; // 二乗誤差
    double trialEntropySum = 0.0; // 交差エントロピー誤差
    
    DJ._print(" ##### ニューラルネットの学習開始 #####");
    for (int i = 0; i <= epoch ; i++) {
      startTime = System.nanoTime(); // 実行開始時刻
      intervalFlag = (i % interval == interval - 1) | 
              (i == epoch); //経過表示フラグ

      // DJ._print("・学習用データのインデックスをシャッフル");
      Collections.shuffle(randomIndexList);
    
      // 全データを学習
      for (int j = 0; j < dataNum; j++) {
        // DJ._print(" Learning loop started. ------------------------------");
        // DJ.print_(" i=" + i + ", j=" + j);
        trialFlag = false; // 試行フラグ　true:試行時　false:学習時
        
        // ランダム・インデックスの取り出し
        int randomIndex = randomIndexList.get(j);
//        DJ.print(", randomIndex = " + randomIndex );
        // DJ._print(" ##### 前向き伝播の呼び出し #####");
//        for (int k = 0; k < inNodeNum; k++) {
//          mid0X[k] = learnData[randomIndex][k];
//        }
        System.arraycopy(learnData[randomIndex], 0, mid0X, 0, inNodeNum);
        middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
        // DJ._print(" ##### 後向き伝播の呼び出し #####");
//        for (int k = 0; k < outNodeNum; k++) {
//          outC[k] = groupData[randomIndex][k]; // 正解データ
//        }
        System.arraycopy(groupData[randomIndex], 0, outC, 0, outNodeNum); // 正解データ
        outputLayer.backward(outC); // 出力層の逆伝播処理を呼び出し
        middleLayer1.backward(); // 中間層１の逆伝播処理を呼び出し
        middleLayer0.backward(); // 中間層０の逆伝播処理を呼び出し
        
        // バッチ処理
        if ((j % batchNum) == 0 ) {
          // DJ._print(" ##### 更新の呼び出し #####");
          middleLayer0.update(); // 中間層０の更新処理を呼び出し
          middleLayer1.update(); // 中間層１の更新処理を呼び出し
          outputLayer.update(); // 出力層の更新処理を呼び出し
        }

        // DJ._print(" ##### 学習時の誤差の保存 #####");
        errorSum += DJ.getSquareError(outY, outC);
        entropySum += DJ.getEntropyError(outY, outC);
//        DJ._print(" 学習時の誤差");
//        DJ.print_(" errorSum=" + errorSum); // 差の二乗誤差
//        DJ.print(" entropySum=" + entropySum); // エントロピー誤差
        
        // DJ._print(" 試行データによる実行（前向き処理のみ） #####");
        trialFlag = true; // 試行フラグ　true:試行時　false:学習時
        for (int k = 0; k < inNodeNum; k++) {
          double trialVal = trialData[randomIndex][k];
          mid0X[k] = trialVal;
          if (k < 2) patternData0[k][j] = trialVal;
          else       patternData1[k - 2][j] = trialVal;
        }
        middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
        middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
        
        // teachDataは共通なので、試行用に再作成しない

        if (intervalFlag) { // 実行時間に影響を与える
          // DJ._print(" パターン表示用の正解と教師データ");
          for (int k = 0; k < outNodeNum; k++) {
            double teachVal = groupData[randomIndex][k];
            if(teachVal == 1.0) {
              patternData0[2][j] = outY[k];
              patternData0[3][j] = k;
              patternData1[2][j] = outY[k];
              patternData1[3][j] = k;
            }
          }
        }
 
        // DJ._print(" ##### 試行時の誤差の保存 #####");
        trialErrorSum += DJ.getSquareError( outY, outC);
        trialEntropySum += DJ.getEntropyError(outY, outC);
//        DJ._print(" 試行時の誤差");
//        DJ._print_(" trialErrorSum = " + trialErrorSum);
//        DJ.print(" trialEntropySum = " + trialEntropySum);

        // DJ._print(" End of one data ---------------");
      } // 全データを学習
      // DJ._print(" End of one epoch ---------------");

      // DJ._print("　エポック毎の誤差の平均値を保存する");
      double learnErrorVal = Math.sqrt(errorSum / dataNum);
      double learnEntropyVal = entropySum / dataNum;
      double trialErrorVal = Math.sqrt(trialErrorSum / dataNum);
      double trialEntropyVal = trialEntropySum / dataNum;
      errorSum = 0.0; entropySum = 0.0;
      trialErrorSum = 0.0; trialEntropySum = 0.0;
      
      // 実行時間の累積
      endTime = System.nanoTime(); // 休止時刻
      double lapTime_ = (endTime - startTime) / 1000000.0;
      if (lapTime_ > 0.0) lapTime = lapTime_; // オーバーフロー対策
      totalTime = totalTime + lapTime; // 経過時間を追加
      
      // 経過表示インターバル
      if (intervalFlag) {
        DJ._print_(" i=" + i);
//        DJ.print_(" SQ-Root Error = " + learnErrorVal);
//        DJ.print(", Entropy Error = " + learnEntropyVal);
        DJ.print_(" Trial SQ-Root Error = " + trialErrorVal);
        DJ.print(", Trial Entropy Error = " + trialEntropyVal);
        
          // グラフ表示用データを代入
          graphData[0] = learnErrorVal;
          graphData[1] = learnEntropyVal;
          graphData[2] = trialErrorVal;
          graphData[3] = trialEntropyVal;
        updateViewer(i); // ビューワの表示を更新
        
        DJ.print_(", lapTime = " + lapTime + "[msec]");
        
        // スレッドの休止（実行速度の調整および経過表示のため）
        synchronized(this) {
          try {
            // DJ.print("Ｅnter to wait(sleepTime)");
            wait(SLEEP_TIME); // タイムアウト付きで待機状態にする。
            // DJ.print("Resume from wait(sleepTime)");
            if (pauseFlag) wait(); // 休止状態
          }
          catch (InterruptedException e) {
             DJ.print("***** ERROR ***** " + getClass().getName() + "\n"
                 + " Exception occur in wait(SLEEP_TIME):" + e.toString());
          }
        } // synchronized()
      } // interval
    
      // 実行処理の中断
      if (abortFlag) {
        DJ._print("##### Abort action requested");
        epoch = i;
      }

    } // End of one epoch
    
    DJ._print(" End of all epoch -------------------------------------------");
    
    DJ._print("・各層の変数の最終値");
    DJ.print(" Last epoch = ", epoch);
    
    DJ._print(" 抽出サンプルによる学習効果の検証 ----------------------------");
    DJ.print(" サンプル0～2はそれぞれ品種1,2,3,1だが,サンプル3は判別困難");
    double[] sampleData0 = {-1.14E+00,-1.32E-01,-1.34E+00,-1.32E+00};
    double[] sampleData1 = {1.04E+00,9.82E-02,3.65E-01,2.64E-01};
    double[] sampleData2 = {7.96E-01,-1.32E-01,8.20E-01,1.05E+00};
//#    double[] sampleData3 = {-1.63E+00,-1.74E+00,-1.40E+00,-1.18E+00};
     double[] sampleData3 = {5.53E-01,-5.92E-01,7.63E-01,3.96E-01};
   double[][] sampleData = new double[4][4];
   sampleData[0] = sampleData0;
   sampleData[1] = sampleData1;
   sampleData[2] = sampleData2;
   sampleData[3] = sampleData3;
    trialFlag = true; // 試行フラグ　true:試行時　false:学習時
    for (int i = 0; i < 4; i++) {
//      for (int j = 0; j < 4; j++) {
//        mid0X[j] = sampleData[i][j];
//      }
      System.arraycopy(sampleData[i], 0, mid0X, 0, 4);
      middleLayer0.forward(); // 中間０層の順伝播処理を呼び出し
      middleLayer1.forward(); // 中間１層の順伝播処理を呼び出し
      outputLayer.forward(); // 出力層の順伝播処理を呼び出し
      DJ.print("出力（品種の推定確率）" + i + " ", outY);
    }
    trialFlag = false; // 試行フラグ　true:試行時　false:学習時
    
    // 実行時間の表示
    DJ._print_("・総実行時間：" + (totalTime / 1000.0) + " [sec]");
    double aveTime = totalTime / epoch;
    DJ.print(", 平均実行時間：" + aveTime + " [msec/epoch]");
    // 処理時間の表示
    DJ.print_("・タスク終了日時：", TimeStamp.getTimeFormated());
    finishTime = System.currentTimeMillis(); // タスク開始時刻
    DJ.print(", タスク処理時間：" + 
            ((finishTime - beginTime) / 1000.0) + " [sec]");
    
  } // Classifier()
  
} // Classification Class

// End of file
