/*
 * Decompiled with CFR 0.152.
 */
package task;

import data.IrisData;
import java.util.ArrayList;
import java.util.Collections;
import layer.LinearLayer;
import task.Task;
import util.DJ;
import util.TimeStamp;

/**
 * Neural-network classifier trained on the data supplied by {@link IrisData}.
 *
 * <p>The network is {@code inNodeNum -> midNodeNum -> midNodeNum -> outNodeNum}: two
 * {@link LinearLayer}s with the {@code active.ReLU} activation followed by an output layer with
 * {@code active.SoftMax}. Each epoch trains on every learning sample once in shuffled order and,
 * for the same index, runs a forward pass on the trial sample to accumulate trial errors. Every
 * {@link #interval} epochs the error curves and per-sample outputs are pushed to the graph and
 * pattern viewers inherited from {@link Task}.
 *
 * <p>NOTE(review): this class was recovered with CFR; the decompiler reported that it removed a
 * self-catching try block in {@code classifier()}, so the original source may have differed there.
 */
public class Classifier extends Task {
    /** True while the network evaluates trial/sample inputs rather than training; not read in this class. */
    private boolean trialFlag = false;
    /** Index of the last epoch; the loop runs epochs 0..epoch inclusive. Lowered to the current epoch on abort. */
    int epoch = 300;
    /** All layers call {@code update()} whenever the per-epoch sample counter is a multiple of this value. */
    int batchNum = 10;
    /** Number of epochs between progress reports and viewer refreshes. */
    int interval = 5;
    /** Initial weight coefficient passed to {@code LinearLayer.initialize}. */
    double initialCoef = 0.01;
    /** Learning rate passed to {@code LinearLayer.initialize}. */
    double eta = 0.05;
    /** Printed with the other parameters; NOTE(review): currently not passed to any layer. */
    double dropOutRate = 0.0;
    /** Number of input features per sample. */
    int inNodeNum = 4;
    /** Width of both hidden layers. */
    int midNodeNum = 50;
    /** Number of output classes. */
    int outNodeNum = 3;
    private LinearLayer middleLayer0;
    private LinearLayer middleLayer1;
    private LinearLayer outputLayer;

    /**
     * Task entry point: records the start time, launches the pattern viewer (two 4-row panels) and
     * the error graph viewer (four curves), then runs {@link #classifier()}.
     */
    @Override
    public void runTask() {
        // Log line: "Task start date/time:"
        DJ._print("\u30fb\u30bf\u30b9\u30af\u958b\u59cb\u65e5\u6642\uff1a", TimeStamp.getTimeFormated());
        this.beginTime = System.currentTimeMillis();
        this.patternViewerFlag = true;
        // These arrays are handed to the viewer launcher; classifier() keeps using them when sized right.
        this.patternData0 = new double[4][75];
        this.patternData1 = new double[4][75];
        this.patternViewerLauncher = DJ.pattern(4, this.patternData0, "Classifier0", 4, this.patternData1, "Classifier1");
        this.graphShift = 1;
        this.graphData = new double[4];
        this.dataName = new String[] {"LearnError", "LearnEntropy", "TrialError", "TrialEntropy"};
        this.graphViewerLauncher = DJ.graph(this.epoch, this.interval, this.dataName, this.graphData);
        this.classifier();
    }

    /**
     * Builds the network, trains it for epochs {@code 0..epoch}, and finally prints the estimated
     * class probabilities for four hand-picked samples plus timing statistics.
     *
     * <p>For every shuffled index {@code j} of an epoch this method:
     * <ol>
     *   <li>runs forward and backward passes on the learning sample;</li>
     *   <li>calls {@code update()} on every layer when {@code j % batchNum == 0};</li>
     *   <li>runs a forward pass on the trial sample with the same index, accumulates trial errors, and
     *       records the inputs (and, on report epochs, the output for the true class) in the pattern
     *       buffers.</li>
     * </ol>
     *
     * <p>Training stops early when {@code abortFlag} is set or the thread is interrupted.
     */
    public void classifier() {
        DJ._print("Classifier.classifier() ==============================");
        this.printParameters();

        // Log line: "Create input data and teacher data"
        DJ._print("\u30fb\u5165\u529b\u30c7\u30fc\u30bf\u3068\u6559\u5e2b\u30c7\u30fc\u30bf\u3092\u4f5c\u6210");
        int partNum = 25;
        int dataNum = 75;
        DJ.print_(" partNum=", partNum);
        DJ.print(", dataNum=", dataNum);
        double[][] learnData = IrisData.getLearnData();
        double[][] trialData = IrisData.getTrialData();
        // NOTE(review): a single group table is used as the teacher for both learn and trial samples.
        double[][] groupData = IrisData.getGroupData();
        DJ.print("\u3000\u5b66\u7fd2\u7528\u30c7\u30fc\u30bf\uff1alearnData=", learnData);
        DJ.print("\u3000\u8a66\u884c\u7528\u30c7\u30fc\u30bf\uff1atrialData=", trialData);
        DJ.print("\u3000\u7a2e\u5225\u30c7\u30fc\u30bf\uff1agroupData=", groupData);
        ArrayList<Integer> randomIndexList = DJ.permutationRandom(dataNum);
        // Reuse the arrays runTask() gave to DJ.pattern(...) instead of silently replacing them.
        this.patternData0 = reusePatternBuffer(this.patternData0, dataNum);
        this.patternData1 = reusePatternBuffer(this.patternData1, dataNum);

        // Log line: "Initialize each layer of the neural net"
        DJ._print("\u30fb\u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u306e\u5404\u5c64\u306e\u521d\u671f\u5316");
        this.middleLayer0 = new LinearLayer(this.inNodeNum, this.midNodeNum);
        this.middleLayer0.initialize(this.eta, this.initialCoef, "active.ReLU", 1);
        double[] mid0X = this.middleLayer0.getX();
        double[] mid0Y = this.middleLayer0.getY();
        this.middleLayer1 = new LinearLayer(this.midNodeNum, this.midNodeNum);
        this.middleLayer1.initialize(this.eta, this.initialCoef, "active.ReLU", 1);
        double[] mid1Y = this.middleLayer1.getY();
        double[] mid1dEdX = this.middleLayer1.getdEdX();
        this.outputLayer = new LinearLayer(this.midNodeNum, this.outNodeNum);
        this.outputLayer.initialize(this.eta, this.initialCoef, "active.SoftMax", 1);
        double[] outY = this.outputLayer.getY();
        double[] outC = this.outputLayer.getC();
        double[] outdEdX = this.outputLayer.getdEdX();
        // Wire the layers together by sharing arrays: each layer's dEdY is the next layer's dEdX,
        // and each layer's X is the previous layer's Y.
        this.middleLayer0.setdEdY(mid1dEdX);
        this.middleLayer1.setdEdY(outdEdX);
        this.middleLayer1.setX(mid0Y);
        this.outputLayer.setX(mid1Y);

        this.patternViewer = this.patternViewerLauncher.getPatternViewer();
        this.graphViewer = this.graphViewerLauncher.getGraphViewer();
        if (this.graphViewer != null) {
            this.graphViewer.shiftGraphAxis(this.graphShift);
        }

        // Log line: "##### Start training the neural net #####"
        DJ._print(" ##### \u30cb\u30e5\u30fc\u30e9\u30eb\u30cd\u30c3\u30c8\u306e\u5b66\u7fd2\u958b\u59cb #####");
        for (int i = 0; i <= this.epoch; ++i) {
            this.startTime = System.nanoTime();
            this.intervalFlag = i % this.interval == this.interval - 1 || i == this.epoch;
            Collections.shuffle(randomIndexList);
            double errorSum = 0.0;
            double entropySum = 0.0;
            double trialErrorSum = 0.0;
            double trialEntropySum = 0.0;

            for (int j = 0; j < dataNum; ++j) {
                int randomIndex = randomIndexList.get(j);

                // --- Training step on the learning sample ---
                this.trialFlag = false;
                System.arraycopy(learnData[randomIndex], 0, mid0X, 0, this.inNodeNum);
                this.forwardAllLayers();
                System.arraycopy(groupData[randomIndex], 0, outC, 0, this.outNodeNum);
                this.outputLayer.backward(outC);
                this.middleLayer1.backward();
                this.middleLayer0.backward();
                // NOTE(review): fires at j = 0, batchNum, 2*batchNum, ..., so the first update follows
                // a single sample; confirm whether j % batchNum == batchNum - 1 was intended.
                if (j % this.batchNum == 0) {
                    this.middleLayer0.update();
                    this.middleLayer1.update();
                    this.outputLayer.update();
                }
                errorSum += DJ.getSquareError(outY, outC);
                entropySum += DJ.getEntropyError(outY, outC);

                // --- Evaluation step on the trial sample with the same index ---
                this.trialFlag = true;
                double[] trialX = trialData[randomIndex];
                for (int k = 0; k < this.inNodeNum; ++k) {
                    mid0X[k] = trialX[k];
                    // Features 0-1 go to the first pattern panel, features 2-3 to the second.
                    if (k < 2) {
                        this.patternData0[k][j] = trialX[k];
                    } else {
                        this.patternData1[k - 2][j] = trialX[k];
                    }
                }
                this.forwardAllLayers();
                if (this.intervalFlag) {
                    // Rows 2/3 hold the output for the true class and that class index.
                    for (int k = 0; k < this.outNodeNum; ++k) {
                        if (groupData[randomIndex][k] == 1.0) {
                            this.patternData0[2][j] = outY[k];
                            this.patternData0[3][j] = k;
                            this.patternData1[2][j] = outY[k];
                            this.patternData1[3][j] = k;
                        }
                    }
                }
                // outC still holds groupData[randomIndex] from the training step above.
                trialErrorSum += DJ.getSquareError(outY, outC);
                trialEntropySum += DJ.getEntropyError(outY, outC);
            }

            double learnErrorVal = Math.sqrt(errorSum / (double) dataNum);
            double learnEntropyVal = entropySum / (double) dataNum;
            double trialErrorVal = Math.sqrt(trialErrorSum / (double) dataNum);
            double trialEntropyVal = trialEntropySum / (double) dataNum;

            this.endTime = System.nanoTime();
            double lapMillis = (double) (this.endTime - this.startTime) / 1000000.0;
            // Keep the previous lap time if the clock reports a non-positive interval.
            if (lapMillis > 0.0) {
                this.lapTime = lapMillis;
            }
            this.totalTime += this.lapTime;

            if (this.intervalFlag) {
                DJ._print_(" i=" + i);
                DJ.print_(" Trial SQ-Root Error = " + trialErrorVal);
                DJ.print(", Trial Entropy Error = " + trialEntropyVal);
                this.graphData[0] = learnErrorVal;
                this.graphData[1] = learnEntropyVal;
                this.graphData[2] = trialErrorVal;
                this.graphData[3] = trialEntropyVal;
                this.updateViewer(i);
                DJ.print_(", lapTime = " + this.lapTime + "[msec]");
                this.pauseAtInterval();
            }
            // An interrupt (restored by pauseAtInterval or delivered elsewhere) is treated as an abort.
            if (this.abortFlag || Thread.currentThread().isInterrupted()) {
                DJ._print("##### Abort action requested");
                this.epoch = i;
            }
        }

        DJ._print(" End of all epoch -------------------------------------------");
        // Log line: "Final values of each layer's variables"
        DJ._print("\u30fb\u5404\u5c64\u306e\u5909\u6570\u306e\u6700\u7d42\u5024");
        DJ.print(" Last epoch = ", this.epoch);
        this.printSampleEstimates(mid0X, outY);
        this.printTimingSummary();
    }

    /** Prints the hyper-parameters and the node count of every layer. */
    private void printParameters() {
        // Log line: "Parameters"
        DJ.print("\u30fb\u30d1\u30e9\u30e1\u30fc\u30bf");
        DJ.print_(" initialCoef=", this.initialCoef);
        DJ.print_(", eta=", this.eta);
        DJ.print(" epoch=", this.epoch);
        DJ.print_(", batchNum=", this.batchNum);
        DJ.print_(", dropOutRate=", this.dropOutRate);
        DJ.print(", interval=", this.interval);
        // Log line: "Number of nodes in each layer"
        DJ._print("\u3000\u5404\u5c64\u306e\u30ce\u30fc\u30c9\u6570");
        DJ.print_(" inNodeNum=", this.inNodeNum);
        DJ.print_(", midNodeNum=", this.midNodeNum);
        DJ.print(", outNodeNum=", this.outNodeNum);
    }

    /** Runs the forward pass through all three layers, in order, on whatever is currently in the input array. */
    private void forwardAllLayers() {
        this.middleLayer0.forward();
        this.middleLayer1.forward();
        this.outputLayer.forward();
    }

    /**
     * Waits 1 ms on this object's monitor and then, if {@code pauseFlag} is set, blocks until notified.
     * On interruption the error is logged and the thread's interrupt status is restored, so the training
     * loop can stop.
     */
    private synchronized void pauseAtInterval() {
        try {
            this.wait(1L);
            if (this.pauseFlag) {
                // NOTE(review): a single wait() (not a while-loop on pauseFlag) means a spurious wake-up
                // resumes training; left as-is because the resume/notify protocol lives in Task.
                this.wait();
            }
        } catch (InterruptedException e) {
            DJ.print("***** ERROR ***** " + this.getClass().getName() + "\n" + " Exception occur in wait(SLEEP_TIME):" + e.toString());
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Feeds four fixed sample vectors through the trained network and prints the output for each.
     *
     * @param inputX  the first layer's input array (filled in place)
     * @param outputY the output layer's result array (printed after each forward pass)
     */
    private void printSampleEstimates(double[] inputX, double[] outputY) {
        // Log lines: "Verify learning effect with extracted samples" and
        // "Samples 0-2 are species 1,2,3,1 respectively, but sample 3 is hard to classify".
        DJ._print(" \u62bd\u51fa\u30b5\u30f3\u30d7\u30eb\u306b\u3088\u308b\u5b66\u7fd2\u52b9\u679c\u306e\u691c\u8a3c ----------------------------");
        DJ.print(" \u30b5\u30f3\u30d7\u30eb0\uff5e2\u306f\u305d\u308c\u305e\u308c\u54c1\u7a2e1,2,3,1\u3060\u304c,\u30b5\u30f3\u30d7\u30eb3\u306f\u5224\u5225\u56f0\u96e3");
        double[][] sampleData = {
            {-1.14, -0.132, -1.34, -1.32},
            {1.04, 0.0982, 0.365, 0.264},
            {0.796, -0.132, 0.82, 1.05},
            {0.553, -0.592, 0.763, 0.396},
        };
        this.trialFlag = true;
        for (int s = 0; s < sampleData.length; ++s) {
            // Copy exactly inNodeNum features, matching how training samples are loaded.
            System.arraycopy(sampleData[s], 0, inputX, 0, this.inNodeNum);
            this.forwardAllLayers();
            // Log prefix: "Output (estimated species probability)"
            DJ.print("\u51fa\u529b\uff08\u54c1\u7a2e\u306e\u63a8\u5b9a\u78ba\u7387\uff09" + s + " ", outputY);
        }
        this.trialFlag = false;
    }

    /** Prints the total and per-epoch execution time, the finish timestamp, and the wall-clock duration. */
    private void printTimingSummary() {
        // Log line: "Total execution time:"
        DJ._print_("\u30fb\u7dcf\u5b9f\u884c\u6642\u9593\uff1a" + this.totalTime / 1000.0 + " [sec]");
        // The training loop runs epochs 0..epoch inclusive, i.e. epoch + 1 of them; the max() guard
        // avoids dividing by zero when no epoch ran.
        double aveTime = this.totalTime / (double) Math.max(1, this.epoch + 1);
        // Log line: "Average execution time:"
        DJ.print(", \u5e73\u5747\u5b9f\u884c\u6642\u9593\uff1a" + aveTime + " [msec/epoch]");
        // Log line: "Task end date/time:"
        DJ.print_("\u30fb\u30bf\u30b9\u30af\u7d42\u4e86\u65e5\u6642\uff1a", TimeStamp.getTimeFormated());
        this.finishTime = System.currentTimeMillis();
        // Log line: "Task processing time:"
        DJ.print(", \u30bf\u30b9\u30af\u51e6\u7406\u6642\u9593\uff1a" + (double) (this.finishTime - this.beginTime) / 1000.0 + " [sec]");
    }

    /**
     * Returns {@code buffer}, zeroed, when it is already a 4 x {@code cols} array; otherwise returns a
     * newly allocated 4 x {@code cols} array. Reusing the buffer keeps the arrays that runTask() passed to
     * {@code DJ.pattern(...)} in use, in case the pattern viewer reads from those references
     * (NOTE(review): confirm against the viewer implementation).
     */
    private static double[][] reusePatternBuffer(double[][] buffer, int cols) {
        if (buffer != null && buffer.length == 4 && buffer[0] != null && buffer[0].length == cols) {
            for (double[] row : buffer) {
                java.util.Arrays.fill(row, 0.0);
            }
            return buffer;
        }
        return new double[4][cols];
    }
}

