/*
 *  Title: DaiJa_V4 (Digital-Learning Aide Instrument by JAva)
 *  @author Yoshinari Sasaki
 *  @version 4.0
 *  @since 2020.7.1
 *  Copyright: 2020, 2021
 */
package task;

import active.Activator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import layer.Layer;
import layer.LinearLayer;
import util.DJ;
import util.TimeStamp;
import view.PatternViewer;

/**
 * <p> 表　題: Class: AndLogic</p>
 * <p> 説　明: 単層ニューラルネットワークによるAND論理の学習</p>
 * <p> 著　者: Yoshinari Sasaki</p>
 * <p> 著作権: Copyright (c) 2020, 2021</p>
 * <p> 作成日: 2020.03.15</p>
 */
public class AndLogic extends Task {
  
  // 学習・試行制御
  int epoch = 20; // エポックの回数
  int interval = 1; // 経過表示間隔
  
  // 学習パラメータ
  double initialCoef = 0.01; // 重みとバイアスの初期値係数
  double eta = 0.1; // 学習係数
  int dataNum = 64; // エポック毎の学習データ数（２値の２入力の組み合わせは４種）
  
  // 各層のノード数
  // 入出力の個数
  int inNodeNum = 2; // 入力の個数（２入力）
  int outNodeNum = 1; // 出力の個数（１出力）
  // レイヤー構成
  private LinearLayer outputLayer; // 出力層
        
 /**
   * Taskクラスで新しく作られたスレッドから呼び出される
   */
  @Override
  public void runTask() {
    DJ._print("・タスク開始日時：", TimeStamp.getTimeFormated());
    beginTime = System.currentTimeMillis(); // タスク開始時刻
    
    // パターン・ビューワ・ランチャを得る
    patternViewerFlag = true; // パターン・ビューワを表示する
    patternData0 = new double[4][dataNum]; // 入力x0、入力x1、出力y、正解t
    patternViewerLauncher = DJ.pattern(
        PatternViewer.PATTERN_LOGIC, patternData0, "AndLogic"); // 論理
    
    // グラフ・ビューワー・ランチャーを得る
    graphShift = 3; // グラフの縦軸を移動
    graphData = new double[2];
    dataName = new String[2]; // データ名
    dataName[0] = "squareError"; // 学習時の二乗誤差
    dataName[1] = "aveError"; // 学習時の平均二乗誤差
    graphViewerLauncher = DJ.graph(epoch, interval, dataName, graphData);
    
    // 単層ニューラルネットワークによるAnd論理の学習
    andLogic();
  }

  /**
   * 単層ニューラルネットワークによるAnd論理の学習
   */
  public void andLogic() {
    DJ._print("AndLogic.andLogic() ==================");
    
    DJ.print("・パラメータ");
    DJ.print_(" initialCoef=", initialCoef);
    DJ.print_(", eta=", eta);
    DJ.print_(", epoch=", epoch);
    DJ.print_(", interval=", interval);
    DJ.print(", 学習データ数：", dataNum);

    DJ._print_(" 入力の個数：", inNodeNum);
    DJ.print(", 出力の個数：", outNodeNum);
    
    DJ._print("・入力データと正解データを作成");
    int[][] inputData = new int[dataNum][2]; // 入力データ
    int[] correctData = new int[dataNum]; // 正解データ
    Random randum =  DJ.getRandom(); // 乱数
    for (int i = 0; i < dataNum; i++) {
      boolean b0 = randum.nextBoolean();
      boolean b1 = randum.nextBoolean();
      // AND論理
/**/      boolean bAnd = b0 & b1; // 「&」はAND論理演算子
//      boolean bAnd = b0 | b1; // 「|」はOR論理演算子
      inputData[i][0] = DJ.boolToInt(b0); // 入力データ
      inputData[i][1] = DJ.boolToInt(b1); // 入力データ
      correctData[i] = DJ.boolToInt(bAnd); // 正解データ
    }
    
    // DJ.print(" 順序をランダムにシャッフルしたインデックスのリスト");
    ArrayList<Integer> indexList = DJ.permutationRandom(dataNum);
    
    DJ.print("inputData", inputData);
    DJ.print("correctData", correctData);
    
    // パターン・ビューワに渡すデータ
    patternData0 = new double[4][dataNum]; // 入力x0、入力x1、出力y、正解t
    
    DJ._print("・ニューラルネットの各層の初期化");
    DJ.print(" outLayer  = new Layer(inNodeNum, outNodeNum)");
    outputLayer = new LinearLayer(inNodeNum, outNodeNum);
    outputLayer.initialize(eta, initialCoef, Activator.RELU, Layer.SGD);
    double[] x = outputLayer.getX();
    double[] y = outputLayer.getY();
    double[] e = outputLayer.getE();
    double[] c = outputLayer.getC();
    
    // DJ._print("PatternViewerへの参照を得る");
    patternViewer = patternViewerLauncher.getPatternViewer();
    // DJ._print("GraphViewerへの参照を得る");
    graphViewer = graphViewerLauncher.getGraphViewer();
    if (graphViewer != null) graphViewer.shiftGraphAxis(graphShift); // グラフの縦軸をシフトする

    double squareError = 0.0; // 誤差、二乗誤差
    double meanError = 0.0; // 累積誤差
    
    DJ._print("・実行開始時刻：", TimeStamp.getTimeFormated());
    
    DJ._print(" ##### ニューラルネットの学習開始 #####");
    for (int i = 0; i <= epoch; i++) {
      startTime = System.nanoTime(); // 実行開始時刻
      intervalFlag = (i % interval == interval - 1)
              | (i == epoch); //経過表示フラグ
      
      // DJ._print("・学習用データのインデックスをシャッフル");
      Collections.shuffle(indexList);
      
      // 全データを学習
      for (int j = 0; j < dataNum; j++) {
        // DJ._print(" Learning loop started. ------------------------------");
//         DJ.print_(" i=" + i + ", j=" + j);

        int index = indexList.get(j); // ランダム・インデックス
//         DJ.print(", index = " + index );

        // DJ._print(" ##### 前向き伝播 #####");
        // 入力データをランダムに取出し
        for (int k = 0; k < 2; k++) x[k] = inputData[index][k];
        outputLayer.forward(); // 出力層の順伝播処理を呼び出し
//          DJ._print("x", x); DJ.print("y", y);
        
        // DJ._print(" ##### 後向き伝播 #####");
        c[0] = correctData[index]; // 正解データをランダムに取出し
        outputLayer.backward(c); // 出力層の逆伝播処理を呼び出し

        // DJ._print(" ##### バイアスと重みの更新 #####");
        outputLayer.update(); // 出力層の更新処理を呼び出し

        if (intervalFlag) {
         // 二乗誤差の算出
          squareError = e[0] * e[0] / 2.0;
          meanError = meanError + squareError;

          // パターン表示用データの保存
          patternData0[0][j] = x[0];
          patternData0[1][j] = x[1];
          patternData0[2][j] = y[0];
          patternData0[3][j] = c[0];
        }

        // DJ._print(" End of one data -------------------------------------");
      }
      // DJ._print(" End of one epoch ---------------");

      // DJ._print("　エポック毎の誤差の平均値を保存する");
      double aveError = meanError / dataNum ;
      meanError = 0.0;
      
      // 実行時間の累積
      endTime = System.nanoTime(); // 休止時刻
      double lapTime_ = (endTime - startTime) / 1000000.0;
      if (lapTime_ > 0.0) lapTime = lapTime_; // オーバーフロー対策
      totalTime = totalTime + (double)lapTime; // 経過時間を追加
      
      if (intervalFlag) {
        // DJ._print(" ##### １エポックの実行結果 #####");
        DJ._print_(" i=" + i);
        DJ.print_(", squareError = ", squareError);
        DJ.print(", aveError = ", aveError);
        
        // グラフ表示用データを代入
        graphData[0] = squareError;
        graphData[1] = aveError;
        updateViewer(i); // ビューワの表示を更新
        
        DJ.print_(" lapTime = ", lapTime); DJ.print("[msec]");
      
        // スレッドの休止（実行速度の調整および経過表示のため）
        synchronized(this) {
          try {
            // DJ.print("Ｅnter to wait(sleepTime)");
            wait(SLEEP_TIME); // タイムアウト付きで待機状態
            // DJ.print("Resume from wait(sleepTime)");
            if (pauseFlag) wait(); // 休止状態
          }
          catch (InterruptedException ex) {
             DJ.print("***** ERROR ***** " + getClass().getName() + "\n"
                 + " Exception occur in wait(sleepTime):" + ex.toString());
          }
        } // synchronized()
      } // interval
      
      // 実行処理の中断
      if (abortFlag) {
        DJ._print("##### Abort action requested");
        epoch = i; // エポック回数を強制的に現在の実行回数に置き換える
      }
      
      // DJ._print(" End of one epoch --------------------------------------");
    } // End of One epoch
    DJ._print(" End of all epoch ---------------");
    
    DJ._print("・各層の変数の最終値");
    DJ.print(" Last epoch = ", epoch);
    DJ._print("　出力層のバイアスと重み");
    DJ.print("outputLayer.b", outputLayer.getB());
    DJ.print("outputLayer.w", outputLayer.getW());
    
    // 実行時間の算出
    DJ._print_("・総実行時間：" + (totalTime / 1000.0) + " [sec]");
    double aveTime = totalTime / epoch;
    DJ.print(", 平均実行時間：" + aveTime + " [msec/epoch]");
    
    DJ.print_("・タスク終了日時：", TimeStamp.getTimeFormated());
    finishTime = System.currentTimeMillis(); // タスク開始時刻
    DJ.print(", タスク処理時間：" + 
            ((finishTime - beginTime) / 1000.0) + " [sec]");

  } // andLogic()
  
} // AndLogic Class

// End of File
