/*
 * Decompiled with CFR 0.152.
 */
package task;

import java.util.ArrayList;
import java.util.Collections;
import layer.ConvoluteLayer;
import layer.LinearLayer;
import layer.PoolingLayer;
import task.Task;
import util.DJ;
import util.LogEditor;
import util.TimeStamp;
import view.GraphViewer;
import view.GraphViewerLauncher;
import view.PatternViewer;
import view.PatternViewerLauncher;

/**
 * Digit-image recognition task using a small convolutional neural network.
 *
 * <p>Network, as wired in {@link #initializeLayers()}:
 * 8x8 input image -> ConvoluteLayer ({@code filterNum} filters of {@code filterHight x filterWidth},
 * stride {@code filterStride}) -> PoolingLayer ({@code poolingHight x poolingWidth} windows)
 * -> LinearLayer ("active.ReLU") -> LinearLayer ("active.ReLU") -> LinearLayer ("active.SoftMax",
 * {@code outNodeNum} outputs).
 *
 * <p>Images are read from {@code input_image.txt} and labels from {@code target_text.txt}; the shuffled
 * sample indices are split 2:1 into a learning set and a trial set. Every {@code interval} epochs (and on
 * the last epoch) the trial accuracy is printed and the pattern/graph viewers are refreshed.
 */
public class ConvNet extends Task {
    /** Set while a trial (evaluation) sample is propagated; not read anywhere in this class. */
    private boolean trialFlag = false;
    /** Current input vector (one flattened image). */
    double[] inData;
    /** Current one-hot teacher vector. */
    double[] teachData;
    /** Index of the last epoch; training runs epochs 0..epoch inclusive (an abort shortens it). */
    int epoch = 500;
    /** Weight updates are applied whenever the in-epoch sample index is a multiple of this value. */
    int batchNum = 10;
    /** Progress is evaluated and displayed on every {@code interval}-th epoch and on the last one. */
    int interval = 5;
    /** Initialisation coefficient passed to every layer's {@code initialize}. */
    double initialCoef = 0.1;
    /** Printed as the filter initialisation coefficient; not passed to any layer in this class. */
    double filterCoef = 0.05;
    /** Learning rate passed to every trainable layer. */
    double eta = 0.01;
    /** Printed only; not used by the training code in this class. */
    double dropOutRate = 0.0;
    int filterNum = 10;
    int filterHight = 3;
    int filterWidth = 3;
    int filterStride = 1;
    /** Padding handed to the pooling layer. */
    int padding = 0;
    int poolingHight = 2;
    int poolingWidth = 2;
    int inImageHight = 8;
    int inImageWidth = 8;
    int inNodeNum = this.inImageHight * this.inImageWidth;
    // Derived sizes: these initialisers rely on the fields above being declared (and initialised) first.
    private final int convoNodeHight = (this.inImageHight - this.filterHight) / this.filterStride + 1;
    private final int convoNodeWidth = (this.inImageWidth - this.filterWidth) / this.filterStride + 1;
    private final int convoNodeNum = this.filterNum * this.convoNodeHight * this.convoNodeWidth;
    private final int poolNodeHight = (this.convoNodeHight + this.padding * 2) / this.poolingHight;
    private final int poolNodeWidth = (this.convoNodeWidth + this.padding * 2) / this.poolingWidth;
    private final int poolNodeNum = this.filterNum * this.poolNodeHight * this.poolNodeWidth;
    private final int midNodeNum = this.poolNodeNum;
    private final int outNodeNum = 10;
    private ConvoluteLayer convoLayer0;
    private PoolingLayer poolLayer0;
    private LinearLayer middleLayer0;
    private LinearLayer middleLayer1;
    private LinearLayer outputLayer;
    int trialCount = 0;
    int correctCount = 0;
    /** A trial counts as correct when the output for the true digit is at least this value. */
    double correctLevel = 0.5;
    double correctRatio;
    boolean patternViewerFlag;
    PatternViewerLauncher patternViewerLauncher;
    PatternViewer patternViewer;
    /** Per trial sample: [sample][outputNode][0 = network output, 1 = teacher value]. */
    double[][][] patternData0;
    /** Snapshot of the convolution filters: [filter][row][column]. */
    double[][][] patternData1;
    int graphShift = 0;
    GraphViewerLauncher graphViewerLauncher;
    GraphViewer graphViewer;
    /** Latest learn/trial error and entropy values, in the order of {@link #dataName}. */
    double[] graphData;
    String[] dataName;

    /**
     * Task entry point: records the start time, launches the pattern and graph viewers, and trains.
     */
    @Override
    public void runTask() {
        DJ._print("\u30fb\u30bf\u30b9\u30af\u958b\u59cb\u65e5\u6642\uff1a", TimeStamp.getTimeFormated());
        this.beginTime = System.currentTimeMillis();
        this.patternData0 = new double[1][this.outNodeNum][2];
        this.patternData1 = new double[this.filterNum][this.filterHight][this.filterWidth];
        this.patternViewerLauncher = DJ.pattern(5, this.patternData0, "ConvNet0", 1, this.patternData1, "ConvNet1");
        this.dataName = new String[] {"LearnError", "LearnEntropy", "TrialError", "TrialEntropy"};
        this.graphData = new double[this.dataName.length];
        this.graphViewerLauncher = DJ.graph(this.epoch, this.interval, this.dataName, this.graphData);
        this.convNet();
    }

    /**
     * Loads and normalises the data, builds the network, trains it for epochs 0..epoch and prints a report.
     *
     * @throws IllegalStateException if the label file has fewer labels than the image file has images
     */
    public void convNet() {
        DJ._print("ConvNet.convNet() ============================");
        DJ._print(" \u30b3\u30f3\u30dc\u30ea\u30e5\u30fc\u30b7\u30e7\u30ca\u30eb\u30fb\u30cb\u30e5\u30fc\u30e9\u30eb\u30fb\u30cd\u30c3\u30c8\u30ef\u30fc\u30af\u306b\u3088\u308b\u753b\u50cf\u8a8d\u8b58");
        this.printConfiguration();

        // ---- input images ----
        DJ._print("\u30fb\u5165\u529b\u753b\u50cf\u3092\u30d5\u30a1\u30a4\u30eb\u304b\u3089\u8aad\u8fbc");
        DJ.print(" DaiJa/resorce/DLearn/input_image.txt, \u30b3\u30f3\u30de\u533a\u5207\u308a, 1797[\u679a]\u00d764[\u753b\u7d20]");
        double[][] inputImage = LogEditor.loadData("input_image.txt");
        int inputImageNum = inputImage.length;
        int imageElementNum = inputImage[0].length;
        DJ.print_(" \u5165\u529b\u753b\u50cf\u306e\u679a\u6570=" + inputImageNum);
        DJ.print(",  \u5165\u529b\u753b\u50cf\u306e\u8981\u7d20\u6570=" + imageElementNum);
        double average = DJ.average(inputImage);
        double stdDev = DJ.stdDev(inputImage);
        DJ.print(" \u5165\u529b\u753b\u50cf\u30c7\u30fc\u30bf\u306e\u5e73\u5747: average = inputImage.average()=", average);
        DJ.print(" \u6a19\u6e96\u504f\u5dee: stdDev = inputImage.stdDev()=", stdDev);
        DJ._print("\u30fb\u5165\u529b\u753b\u50cf\u30c7\u30fc\u30bf\u306e\u6b63\u898f\u5316\uff1a\u5e73\u5747\u5024 = 0, \u6a19\u6e96\u504f\u5dee = 1");
        double[][] inputImageStd = DJ.normalize(inputImage);
        double averageNomal = DJ.average(inputImageStd);
        double stdDevNomal = DJ.stdDev(inputImageStd);
        DJ.print_(" \u6b63\u898f\u5316\u3055\u308c\u305f\u5165\u529b\u753b\u50cf\u30c7\u30fc\u30bf\u306e\u5e73\u5747:averageNomal=", averageNomal);
        DJ.print(",  \u6a19\u6e96\u504f\u5dee=", stdDevNomal);

        // ---- labels ----
        DJ._print("\u30fb\u6b63\u89e3\u6587\u5b57\u5217\u3092\u30d5\u30a1\u30a4\u30eb\u304b\u3089\u8aad\u8fbc");
        DJ.print(" \u30d7\u30ed\u30b8\u30a7\u30af\u30c8\u30fb\u30d5\u30a9\u30eb\u30c0/resorce/target_text.txt, \u6574\u6570\u4e00\u6841, \u30b3\u30f3\u30de\u533a\u5207\u308a, 1797[\u6587\u5b57]");
        int[][] targetText = LogEditor.loadIntData("target_text.txt");
        int targetNum = targetText[0].length;
        DJ.print(" \u6b63\u89e3\u6587\u5b57\u5217\u306e\u9577\u3055:targetNum=", targetNum);
        // Every image index is used below to look up its label, so missing labels would fail mid-training.
        if (targetNum < inputImageNum) {
            throw new IllegalStateException("target_text.txt has " + targetNum
                    + " labels but input_image.txt has " + inputImageNum + " images");
        }
        DJ._print("\u30fb\u6b63\u89e3\u30c7\u30fc\u30bf\u3092one-hot\u8868\u73fe\u3067\u751f\u6210, 10[\u884c]\u00d71797[\u5217]");
        double[][] correctData = toOneHot(targetText[0], this.outNodeNum);

        // ---- learn / trial split ----
        DJ._print("\u30fb\u30e9\u30f3\u30c0\u30e0\u306a\u30a4\u30f3\u30c7\u30c3\u30af\u30b9\u3092\u751f\u6210, 1797[\u500b]");
        DJ.print(" \u30e9\u30f3\u30c0\u30e0\u30fb\u30a4\u30f3\u30c7\u30c3\u30af\u30b9\u6570:inputImageNum=", inputImageNum);
        ArrayList<Integer> randomIndexList = DJ.permutationRandom(inputImageNum);
        int randomListNum = randomIndexList.size();
        DJ.print(" \u30e9\u30f3\u30c0\u30e0\u30fb\u30ea\u30b9\u30c8\u9577:ranListNum=", randomListNum);
        DJ.print(" \u30e9\u30f3\u30c0\u30e0\u306a\u30a4\u30f3\u30c7\u30c3\u30af\u30b9\u30922:1\u3067\u5b66\u7fd2\u7528\u3068\u8a66\u884c\u7528\u306b\u5272\u308a\u632f\u308a");
        ArrayList<Integer> learnIndexList = new ArrayList<>();
        ArrayList<Integer> trialIndexList = new ArrayList<>();
        for (int n = 0; n < randomListNum; n++) {
            // Every third shuffled index goes to the trial set, the other two to the learning set (2:1).
            if (n % 3 != 0) {
                learnIndexList.add(randomIndexList.get(n));
            } else {
                trialIndexList.add(randomIndexList.get(n));
            }
        }
        int learnNum = learnIndexList.size();
        int trialNum = trialIndexList.size();
        DJ.print_(" \u5b66\u7fd2\u7528\u30c7\u30fc\u30bf\u6570:learnNum=", learnNum);
        DJ.print(", \u8a66\u884c\u7528\u30c7\u30fc\u30bf\u6570:trialNum=", trialNum);
        this.inData = new double[this.inNodeNum];
        this.teachData = new double[this.outNodeNum];

        // ---- network ----
        DJ._print("\u30fb\u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u306e\u5404\u5c64\u306e\u521d\u671f\u5316");
        // outY is the output layer's y buffer; it is re-read after every forward() call below.
        double[] outY = this.initializeLayers();
        this.patternViewer = this.patternViewerLauncher.getPatternViewer();
        this.graphViewer = this.graphViewerLauncher.getGraphViewer();
        if (this.graphViewer != null) {
            this.graphViewer.shiftGraphAxis(1);
        }

        // ---- training ----
        DJ._print(" ##### \u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u306e\u5b66\u7fd2\u958b\u59cb #####");
        // The bound is re-read every iteration because an abort request sets this.epoch to the current epoch.
        for (int i = 0; i <= this.epoch; i++) {
            this.startTime = System.nanoTime();
            this.intervalFlag = i % this.interval == this.interval - 1 || i == this.epoch;
            if (this.intervalFlag) {
                this.trialCount = 0;
                this.correctCount = 0;
                // A fresh array per displayed epoch, so an array already handed to the viewer is never overwritten.
                this.patternData0 = new double[trialNum][this.outNodeNum][2];
            }
            Collections.shuffle(learnIndexList);
            Collections.shuffle(trialIndexList);
            double learnErrorSum = 0.0;
            double learnEntropySum = 0.0;
            double trialErrorSum = 0.0;
            double trialEntropySum = 0.0;
            for (int j = 0; j < learnNum; j++) {
                this.trialFlag = false;
                int learnIndex = learnIndexList.get(j);
                this.inData = inputImageStd[learnIndex];
                this.forward(this.inData);
                this.teachData = correctData[learnIndex];
                this.backward(this.teachData);
                // NOTE(review): fires at j = 0, batchNum, 2*batchNum, ..., so the first update of each epoch
                // follows a single sample; confirm whether (j + 1) % batchNum was intended.
                if (j % this.batchNum == 0) {
                    this.update();
                }
                learnErrorSum += DJ.getSquareError(outY, this.teachData);
                learnEntropySum += DJ.getEntropyError(outY, this.teachData);
                // Interleave one trial sample (forward only, no learning) per learning sample.
                if (j < trialNum) {
                    this.trialFlag = true;
                    int trialIndex = trialIndexList.get(j);
                    this.inData = inputImageStd[trialIndex];
                    this.forward(this.inData);
                    this.teachData = correctData[trialIndex];
                    trialErrorSum += DJ.getSquareError(outY, this.teachData);
                    trialEntropySum += DJ.getEntropyError(outY, this.teachData);
                    if (this.intervalFlag) {
                        for (int k = 0; k < this.outNodeNum; k++) {
                            this.patternData0[j][k][0] = outY[k];
                            this.patternData0[j][k][1] = this.teachData[k];
                        }
                        ++this.trialCount;
                        int correctNumber = targetText[0][trialIndex];
                        if (outY[correctNumber] >= this.correctLevel) {
                            ++this.correctCount;
                        }
                    }
                    this.trialFlag = false;
                }
            }
            double learnErrorVal = Math.sqrt(learnErrorSum / (double) learnNum);
            double learnEntropyVal = learnEntropySum / (double) learnNum;
            double trialErrorVal = Math.sqrt(trialErrorSum / (double) trialNum);
            double trialEntropyVal = trialEntropySum / (double) trialNum;
            this.endTime = System.nanoTime();
            double lapTime_ = (double) (this.endTime - this.startTime) / 1000000.0;
            // A zero lap (timer granularity) keeps the previous lapTime.
            if (lapTime_ > 0.0) {
                this.lapTime = lapTime_;
            }
            this.totalTime += this.lapTime;
            boolean interrupted = false;
            if (this.intervalFlag) {
                DJ.print_(" i=" + i);
                this.correctRatio = (double) this.correctCount / (double) this.trialCount * 100.0;
                DJ.print_(", \u6b63\u7b54\u6bd4\u7387=", String.valueOf(this.correctRatio) + "[%]");
                this.copyFilterWeights();
                this.refreshViewers(i, learnErrorVal, learnEntropyVal, trialErrorVal, trialEntropyVal);
                DJ.print(", lapTime=" + this.lapTime + "[msec]");
                interrupted = this.pauseIfRequested();
            }
            // An interrupt is treated like an abort request: finish after the current epoch.
            if (this.abortFlag || interrupted) {
                DJ._print("##### Abort action requested");
                this.epoch = i;
            }
        }

        // ---- final report ----
        DJ._print(" End of all epoch --------------------------------------------");
        DJ._print(" \u62bd\u51fa\u30b5\u30f3\u30d7\u30eb\u306b\u3088\u308b\u5b66\u7fd2\u52b9\u679c\u306e\u691c\u8a3c ----------------------------");
        this.trialFlag = true;
        // Show up to 10 trial samples; the trial set may be smaller than that.
        int sampleNum = Math.min(10, trialNum);
        for (int j = 0; j < sampleNum; j++) {
            DJ._print("\u62bd\u51fa\u30b5\u30f3\u30d7\u30eb:" + j);
            int trialIndex = trialIndexList.get(j);
            int correctDigit = targetText[0][trialIndex];
            DJ.print("\u6b63\u89e3\u6587\u5b57:correctDigit=", correctDigit);
            this.inData = inputImageStd[trialIndex];
            this.forward(this.inData);
            DJ.print("\u8a66\u884c\u51fa\u529b:outLayer.y=", DJ.reshape(outY, this.outNodeNum, 1));
        }
        this.trialFlag = false;
        DJ._print("\u30fb\u6700\u7d42\u8aa4\u5dee: ");
        for (int k = 0; k < this.graphData.length; k++) {
            DJ.print("  " + this.dataName[k] + "=" + this.graphData[k]);
        }
        DJ._print_("\u30fb\u8a66\u884c\u56de\u6570=", this.trialCount);
        DJ.print_(", \u6b63\u7b54\u56de\u6570=", this.correctCount);
        DJ.print("\u30fb\u6700\u7d42\u6b63\u7b54\u6bd4\u7387=", String.valueOf(this.correctRatio) + "[%]");
        DJ.print_("\u30fb\u7dcf\u5b9f\u884c\u6642\u9593\uff1a" + this.totalTime / 1000.0 + " [sec]");
        // Epochs 0..epoch inclusive were run, i.e. epoch + 1 of them (also avoids dividing by zero).
        double aveTime = this.totalTime / (double) (this.epoch + 1);
        DJ.print(", \u5e73\u5747\u5b9f\u884c\u6642\u9593\uff1a" + aveTime + " [msec/epoch]");
        DJ.print_("\u30fb\u30bf\u30b9\u30af\u7d42\u4e86\u65e5\u6642\uff1a", TimeStamp.getTimeFormated());
        this.finishTime = System.currentTimeMillis();
        DJ.print(", \u30bf\u30b9\u30af\u51e6\u7406\u6642\u9593\uff1a" + (double) (this.finishTime - this.beginTime) / 1000.0 + " [sec]");
    }

    /** Prints the hyper-parameters and the layer/filter/pooling configuration. */
    private void printConfiguration() {
        DJ._print("\u30fb\u30d1\u30e9\u30e1\u30fc\u30bf");
        DJ.print_(" \u91cd\u307f\u3068\u30d0\u30a4\u30a2\u30b9\u306e\u521d\u671f\u5024\u4fc2\u6570:initialCoef=", this.initialCoef);
        DJ.print_(" \u30d5\u30a3\u30eb\u30bf\u306e\u521d\u671f\u5024\u4fc2\u6570:filterCoef=", this.filterCoef);
        DJ.print_(", \u5b66\u7fd2\u4fc2\u6570:eta=", this.eta);
        DJ.print_(" \u30a8\u30dd\u30c3\u30af\u56de\u6570:epoch=", this.epoch);
        DJ.print_(", \u30d0\u30c3\u30c1\u6570:batchNum=", this.batchNum);
        DJ.print(", \u30c9\u30ed\u30c3\u30d7\u30fb\u30a2\u30a6\u30c8\u7387:dropOutRate=", this.dropOutRate);
        DJ.print(", \u7d4c\u904e\u8868\u793a\u9593\u9694:interval=", this.interval);
        DJ._print("\u30fb\u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u306e\u69cb\u6210");
        DJ.print_(" \u5165\u529b\u5c64\u30ce\u30fc\u30c9\u6570:inNodeNum=", this.inNodeNum);
        DJ.print(", \u7573\u8fbc\u5c64\u30ce\u30fc\u30c9\u6570:convoNodeNum=", this.convoNodeNum);
        DJ.print(" \u30d7\u30fc\u30ea\u30f3\u30b0\u5c64\u30ce\u30fc\u30c9\u6570:poolNodeNum=", this.poolNodeNum);
        DJ.print_(" \u4e2d\u9593\u5c64\u30ce\u30fc\u30c9\u6570:midNodeNum=", this.midNodeNum);
        DJ.print(", \u51fa\u529b\u5c64\u30ce\u30fc\u30c9\u6570:outNodeNum=", this.outNodeNum);
        DJ._print("\u30fb\u7573\u307f\u8fbc\u307f\u3068\u30d7\u30fc\u30ea\u30f3\u30b0\u306e\u69cb\u6210");
        DJ.print_(" \u30d5\u30a3\u30eb\u30bf\u6570:filterNum=", this.filterNum);
        DJ.print_(", \u30d5\u30a3\u30eb\u30bf\u306e\u9ad8\u3055:filterHight=", this.filterHight);
        DJ.print(", \u30d5\u30a3\u30eb\u30bf\u306e\u5e45:filterWidth=", this.filterWidth);
        DJ.print_(" \u30d5\u30a3\u30eb\u30bf\u306e\u79fb\u52d5\u91cf:filterStride=", this.filterStride);
        DJ.print(", \u30d1\u30c7\u30f3\u30b0\u91cf\uff08\u7573\u8fbc\u5c64\u7528\uff09:padding=", this.padding);
        DJ.print_(" \u30d7\u30fc\u30ea\u30f3\u30b0\u67a0\u306e\u9ad8\u3055:poolingHight=", this.poolingHight);
        DJ.print(", \u30d7\u30fc\u30ea\u30f3\u30b0\u67a0\u306e\u5e45:poolingWidth=", this.poolingWidth);
    }

    /**
     * Builds a one-hot matrix from class labels.
     *
     * @param labels   class index per sample; each must be in [0, classNum)
     * @param classNum number of classes (columns)
     * @return a {@code labels.length x classNum} matrix with 1.0 at {@code [i][labels[i]]} and 0.0 elsewhere
     */
    private static double[][] toOneHot(int[] labels, int classNum) {
        double[][] oneHot = new double[labels.length][classNum];
        for (int i = 0; i < labels.length; i++) {
            oneHot[i][labels[i]] = 1.0;
        }
        return oneHot;
    }

    /**
     * Creates the five layers and connects their buffers.
     *
     * <p>Each layer's output/gradient buffers are fetched through its getters and handed to the neighbouring
     * layer's setters (presumably shared by reference; confirm in the layer classes). The first dense layer's
     * input and the pooling layer's output gradient are copied explicitly in {@link #forward(double[])} and
     * {@link #backward(double[])}.
     *
     * @return the output layer's y buffer, which holds the prediction after each forward pass
     */
    private double[] initializeLayers() {
        this.convoLayer0 = new ConvoluteLayer(this.inImageHight, this.inImageWidth);
        // NOTE(review): filterCoef is reported as the filter initialisation coefficient, yet initialCoef is
        // what gets passed here; confirm which one ConvoluteLayer.initialize is meant to receive.
        this.convoLayer0.initialize(this.eta, this.initialCoef, this.filterNum, this.filterHight, this.filterWidth, this.filterStride, 1);
        ArrayList<double[][]> convo0yList = this.convoLayer0.get_yList();
        this.poolLayer0 = new PoolingLayer(this.convoNodeHight, this.convoNodeWidth);
        this.poolLayer0.initialize(this.filterNum, this.padding, this.poolingHight, this.poolingWidth);
        ArrayList<double[][]> pool0dEdXList = this.poolLayer0.get_dEdXList();
        this.middleLayer0 = new LinearLayer(this.poolNodeNum, this.midNodeNum);
        this.middleLayer0.initialize(this.eta, this.initialCoef, "active.ReLU", 1);
        double[] mid0Y = this.middleLayer0.getY();
        this.middleLayer1 = new LinearLayer(this.midNodeNum, this.midNodeNum);
        this.middleLayer1.initialize(this.eta, this.initialCoef, "active.ReLU", 1);
        double[] mid1Y = this.middleLayer1.getY();
        double[] mid1dEdX = this.middleLayer1.getdEdX();
        this.outputLayer = new LinearLayer(this.midNodeNum, this.outNodeNum);
        this.outputLayer.initialize(this.eta, this.initialCoef, "active.SoftMax", 1);
        double[] outY = this.outputLayer.getY();
        double[] outdEdX = this.outputLayer.getdEdX();
        // Dense chain: gradients flow back from the next layer, activations forward from the previous one.
        this.middleLayer0.setdEdY(mid1dEdX);
        this.middleLayer1.setdEdY(outdEdX);
        this.middleLayer1.setX(mid0Y);
        this.outputLayer.setX(mid1Y);
        // Convolution <-> pooling.
        this.convoLayer0.set_dEdYList(pool0dEdXList);
        this.poolLayer0.set_xList(convo0yList);
        return outY;
    }

    /** Copies the current convolution filter weights into {@code patternData1} for display. */
    private void copyFilterWeights() {
        ArrayList<double[][]> filterWeightList = this.convoLayer0.getFilterWeightList();
        for (int k = 0; k < this.filterNum; k++) {
            double[][] filterWeight = filterWeightList.get(k);
            for (int m = 0; m < this.filterHight; m++) {
                System.arraycopy(filterWeight[m], 0, this.patternData1[k][m], 0, this.filterWidth);
            }
        }
    }

    /**
     * Stores the latest error values in {@code graphData} and pushes the current data to the viewers.
     *
     * <p>{@code graphData} is always updated (the final report prints it); the viewers are fetched lazily
     * from their launchers and skipped while still unavailable.
     */
    private void refreshViewers(int epochIndex, double learnErrorVal, double learnEntropyVal,
            double trialErrorVal, double trialEntropyVal) {
        this.graphData[0] = learnErrorVal;
        this.graphData[1] = learnEntropyVal;
        this.graphData[2] = trialErrorVal;
        this.graphData[3] = trialEntropyVal;
        if (this.patternViewer == null) {
            this.patternViewer = this.patternViewerLauncher.getPatternViewer();
        }
        if (this.patternViewer != null) {
            this.patternViewer.setVisible(true);
            this.patternViewer.updatePattern(this.patternData0, this.patternData1);
        }
        if (this.graphViewer == null) {
            this.graphViewer = this.graphViewerLauncher.getGraphViewer();
            if (this.graphViewer != null) {
                this.graphViewer.shiftGraphAxis(1);
            }
        }
        if (this.graphViewer != null) {
            this.graphViewer.updateGraph(epochIndex, this.graphData);
        }
    }

    /**
     * Waits up to 1 ms on this object's monitor and, if {@code pauseFlag} is set, waits until notified.
     *
     * @return {@code true} if the thread was interrupted while waiting; the interrupt status is restored
     */
    private synchronized boolean pauseIfRequested() {
        try {
            this.wait(1L);
            if (this.pauseFlag) {
                this.wait();
            }
            return false;
        } catch (InterruptedException e) {
            DJ.print("***** ERROR ***** ConvolutionalNN.\n Exception occur in wait(sleepTime):" + e.toString());
            Thread.currentThread().interrupt();
            return true;
        }
    }

    /**
     * Runs one forward pass through all layers; the prediction ends up in the output layer's y buffer.
     *
     * @param x row-major image pixels, length {@code inImageHight * inImageWidth}
     */
    private void forward(double[] x) {
        double[][] image = this.convoLayer0.get_xList().get(0);
        for (int i = 0; i < this.inImageHight; i++) {
            System.arraycopy(x, this.inImageWidth * i, image[i], 0, this.inImageWidth);
        }
        this.convoLayer0.forward();
        this.poolLayer0.forward();
        // Flatten the pooled maps (filter-major, then row-major) into the first dense layer's input.
        ArrayList<double[][]> yList = this.poolLayer0.get_yList();
        double[] mid0X = this.middleLayer0.getX();
        for (int k = 0; k < this.filterNum; k++) {
            double[][] pooled = yList.get(k);
            for (int i = 0; i < this.poolNodeHight; i++) {
                System.arraycopy(pooled[i], 0, mid0X, this.poolNodeWidth * (this.poolNodeHight * k + i), this.poolNodeWidth);
            }
        }
        this.middleLayer0.forward();
        this.middleLayer1.forward();
        this.outputLayer.forward();
    }

    /**
     * Back-propagates the error for the most recent forward pass through all layers.
     *
     * @param teach teacher (one-hot) vector of length {@code outNodeNum}, copied into the output layer's C buffer
     */
    private void backward(double[] teach) {
        double[] outC = this.outputLayer.getC();
        System.arraycopy(teach, 0, outC, 0, this.outNodeNum);
        this.outputLayer.backward(outC);
        this.middleLayer1.backward();
        this.middleLayer0.backward();
        // Un-flatten the first dense layer's input gradient into per-filter pooled maps (inverse of forward()).
        double[] mid0dEdX = this.middleLayer0.getdEdX();
        ArrayList<double[][]> dEdYList = this.poolLayer0.get_dEdYList();
        for (int k = 0; k < this.filterNum; k++) {
            double[][] dEdY = dEdYList.get(k);
            for (int i = 0; i < this.poolNodeHight; i++) {
                System.arraycopy(mid0dEdX, this.poolNodeWidth * (this.poolNodeHight * k + i), dEdY[i], 0, this.poolNodeWidth);
            }
        }
        this.poolLayer0.backward();
        this.convoLayer0.backward();
    }

    /** Applies the accumulated parameter updates of every layer. */
    private void update() {
        this.convoLayer0.update();
        this.poolLayer0.update();
        this.middleLayer0.update();
        this.middleLayer1.update();
        this.outputLayer.update();
    }
}

