import * as tf from '@tensorflow/tfjs';

import i18n from "./i18n";

let isTFStarted = false;

/**
 * Initialise the shared image-classification state on `window` and start
 * downloading the MobileNet feature extractor.
 *
 * DOM element references are looked up on every call. The shared state
 * (constants, example arrays, flags) is initialised and the model download is
 * started only on the first successful call. Repeated calls therefore do not
 * discard already-gathered feature tensors without disposing them, and do not
 * fetch the model a second time. If the model fails to load, the flag is
 * cleared so a later call can retry.
 */
export function startTF() {
    window.STATUS = document.getElementById('status');
    window.VIDEO = document.getElementById('webcam');
    window.ENABLE_CAM_BUTTON = document.getElementById('enableCam');
    window.RESET_BUTTON = document.getElementById('reset');
    window.TRAIN_BUTTON = document.getElementById('train');

    if (isTFStarted) {
        return;
    }
    isTFStarted = true;

    // Frames are resized to this size before being fed to MobileNet.
    window.MOBILE_NET_INPUT_WIDTH = 224;
    window.MOBILE_NET_INPUT_HEIGHT = 224;
    // Sentinel value of gatherDataState meaning "not collecting examples".
    window.STOP_DATA_GATHER = -1;
    window.CLASS_NAMES = [];

    window.mobilenet = undefined;
    window.gatherDataState = window.STOP_DATA_GATHER;
    window.videoPlaying = false;
    window.trainingDataInputs = [];
    window.trainingDataOutputs = [];
    window.examplesCount = [];
    window.predict = false;

    loadMobileNetFeatureModel().catch(function (err) {
        // Allow a later startTF() call to retry the download.
        isTFStarted = false;
        console.error('Failed to load the MobileNet feature model:', err);
    });
}

/**
 * Discard all gathered training examples and the current prediction state.
 *
 * Disposes every stored feature tensor, clears the example/label/count arrays
 * in place, clears class names and predictions, updates the status text (if
 * the status element exists), notifies 'refreshPredictions' listeners and logs
 * the number of live tensors. Safe to call before startTF() has created the
 * shared arrays.
 */
export function reset() {
    window.predict = false;

    // The arrays are emptied in place rather than reassigned.
    if (Array.isArray(window.examplesCount)) {
        window.examplesCount.length = 0;
    }
    if (Array.isArray(window.trainingDataInputs)) {
        // Feature tensors are created by this module and must be freed explicitly.
        window.trainingDataInputs.forEach(function (tensor) {
            tensor.dispose();
        });
        window.trainingDataInputs.length = 0;
    }
    if (Array.isArray(window.trainingDataOutputs)) {
        window.trainingDataOutputs.length = 0;
    }

    if (window.STATUS) {
        window.STATUS.innerText = i18n.t("image_classification.no_neural_network_trained");
    }

    window.CLASS_NAMES = [];
    window.PREDICTION_ARRAY = [];
    window.dispatchEvent(new Event('refreshPredictions'));

    console.log('Tensors in memory: ' + tf.memory().numTensors);
}

/**
 * Fetch the MobileNet v3 graph model from its relative URL and publish it as
 * `window.mobilenet`, the feature extractor used when capturing examples and
 * when predicting. Resolves with no value once the model is assigned.
 */
async function loadMobileNetFeatureModel() {
    const modelJsonPath = './imageClassification/mobilenet_v3/model.json';
    const graphModel = await tf.loadGraphModel(modelJsonPath);
    window.mobilenet = graphModel;
}

/**
 * `onEpochEnd` training callback: logs the epoch index together with the
 * logs object passed by the training loop.
 */
function logProgress(epoch, logs) {
    const label = `Data for epoch ${epoch}`;
    console.log(label, logs);
}

/**
 * Build and compile the classification head that is trained on top of the
 * MobileNet feature vectors, storing it as `window.model`.
 *
 * Any previously built head is disposed first so repeated training runs do not
 * leak its weight tensors.
 *
 * @param {string|number|null|undefined} learningRate - Adam learning rate. A
 *   missing, non-numeric or non-positive value (including the `-1` "use the
 *   default" sentinel) selects tf.js's default `'adam'` optimizer.
 */
function createModel(learningRate) {
    if (window.model) {
        window.model.dispose();
    }

    window.model = tf.sequential();
    // NOTE(review): inputShape assumes the MobileNet feature vector has 1024 entries — confirm against the bundled model.
    window.model.add(tf.layers.dense({inputShape: [1024], units: 128, activation: 'relu'}));
    window.model.add(tf.layers.dense({units: window.CLASS_NAMES.length, activation: 'softmax'}));
    window.model.summary();

    const rate = parseFloat(learningRate);
    const optimizer = Number.isFinite(rate) && rate > 0 ? tf.train.adam(rate) : 'adam';

    window.model.compile({
        optimizer: optimizer,
        loss: (window.CLASS_NAMES.length === 2) ? 'binaryCrossentropy' : 'categoricalCrossentropy',
        metrics: ['accuracy'],
    });
}

/**
 * Parse a user-supplied numeric setting as a base-10 integer.
 *
 * @param {*} value - Raw setting (string, number, null, ...).
 * @param {number} fallback - Returned when `value` is missing, non-numeric or not positive.
 * @returns {number} The parsed positive integer, or `fallback`.
 */
function parsePositiveInt(value, fallback) {
    const parsed = parseInt(value, 10);
    return parsed > 0 ? parsed : fallback;
}

/**
 * Build a fresh classification head (via createModel) and fit it on the
 * feature vectors gathered so far.
 *
 * @param {Array} cnnSettings - Positional settings: index 1 = epochs
 *   (default 25), index 2 = batch size (default 5), index 3 = learning rate
 *   (forwarded to createModel). Index 0 is not read here.
 * @returns {Promise<void>} Resolves once fitting has finished and
 *   `window.predict` is enabled. Rejects if fitting fails, in which case
 *   prediction stays disabled. Resolves immediately, without training, when
 *   no examples have been gathered.
 */
async function train(cnnSettings) {
    if (window.trainingDataInputs.length === 0) {
        // tf.stack([]) would throw; there is nothing to learn from yet.
        console.warn('No training examples gathered yet; skipping training.');
        return;
    }

    createModel(cnnSettings[3]);
    window.predict = false;

    // Shuffle inputs and labels together, in place, so the pairs stay aligned.
    tf.util.shuffleCombo(window.trainingDataInputs, window.trainingDataOutputs);

    // tf.tidy frees the intermediate int32 label tensor and keeps only the returned pair.
    const batch = tf.tidy(function () {
        return {
            inputs: tf.stack(window.trainingDataInputs),
            labels: tf.oneHot(tf.tensor1d(window.trainingDataOutputs, 'int32'), window.CLASS_NAMES.length),
        };
    });

    try {
        await window.model.fit(batch.inputs, batch.labels, {
            shuffle: true,
            batchSize: parsePositiveInt(cnnSettings[2], 5),
            epochs: parsePositiveInt(cnnSettings[1], 25),
            callbacks: {onEpochEnd: logProgress},
        });
        window.predict = true;
    } finally {
        // Release the stacked training tensors even if fitting throws.
        batch.inputs.dispose();
        batch.labels.dispose();
    }
}

/**
 * Public entry point for training the classifier.
 *
 * @param {Array} [cnnSettings] - Positional settings array forwarded to
 *   train(): index 1 = epochs, index 2 = batch size, index 3 = learning rate.
 *   `null` entries select the defaults.
 * @returns {Promise<void>} The training promise, so callers can await
 *   completion or handle failures.
 */
export function trainModel(cnnSettings = ["cnnConfig", null, null, null ]) {
    return train(cnnSettings);
}

/**
 * Escape text so it can be interpolated safely into HTML markup.
 *
 * @param {*} text - Value to escape (converted with String()).
 * @returns {string} The HTML-escaped text.
 */
function escapeHtml(text) {
    return String(text)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');
}

/**
 * Look up `className` in window.CLASS_NAMES, registering it if it is new, and
 * make it the current gather target.
 *
 * @param {string} className - Class to collect examples for.
 * @returns {number} The class index.
 */
function selectGatherClass(className) {
    let classNumber = window.CLASS_NAMES.indexOf(className);
    if (classNumber === -1) {
        classNumber = window.CLASS_NAMES.push(className) - 1;
    }
    window.gatherDataState = classNumber;
    return classNumber;
}

/**
 * Run `pixelSource` through MobileNet and store the resulting feature tensor
 * as a training example labelled `classNumber`.
 *
 * @param {number} classNumber - Label index to record.
 * @param {*} pixelSource - Anything tf.browser.fromPixels accepts (image, video, canvas...).
 * @returns {boolean} False, with nothing recorded, when the feature model has not loaded yet.
 */
function recordExample(classNumber, pixelSource) {
    if (!window.mobilenet) {
        console.warn('MobileNet feature model is not loaded yet; example skipped.');
        return false;
    }

    const imageFeatures = tf.tidy(function () {
        const frame = tf.browser.fromPixels(pixelSource);
        const resized = tf.image.resizeBilinear(frame, [window.MOBILE_NET_INPUT_HEIGHT,
            window.MOBILE_NET_INPUT_WIDTH], true);
        const normalized = resized.div(255);
        return window.mobilenet.predict(normalized.expandDims()).squeeze();
    });

    window.trainingDataInputs.push(imageFeatures);
    window.trainingDataOutputs.push(classNumber);
    window.examplesCount[classNumber] = (window.examplesCount[classNumber] || 0) + 1;
    return true;
}

/**
 * Rewrite the status element with one "label + counter" line per class.
 * Class labels are HTML-escaped. The i18n counter strings are treated as
 * trusted markup. Does nothing if the status element is missing.
 *
 * @param {function(string): string} labelFor - Maps a class name to its displayed label.
 */
function renderExampleCounts(labelFor) {
    if (!window.STATUS) {
        return;
    }
    // Build the markup once instead of re-parsing innerHTML for every class.
    window.STATUS.innerHTML = window.CLASS_NAMES.map(function (name, n) {
        return escapeHtml(labelFor(name))
            + i18n.t("image_classification.image_counter_1")
            + (window.examplesCount[n] ? window.examplesCount[n] : '0')
            + i18n.t("image_classification.image_counter_2")
            + '. <br />';
    }).join('');
}

/**
 * Record one training example for `className` from a still image; the class
 * name itself is shown as the label in the status text.
 *
 * @param {string} className - Class to label the example with.
 * @param {*} img - Image source accepted by tf.browser.fromPixels.
 */
export async function createScreenshotForDogsAndCats(className, img) {
    const classNumber = selectGatherClass(className);
    if (recordExample(classNumber, img)) {
        renderExampleCounts(function (name) {
            return name;
        });
    }
}

/**
 * Record one training example for `className` from the current webcam frame.
 * The class is registered even when the video is not playing, but a frame is
 * captured only while `window.videoPlaying` is true. Labels in the status text
 * are looked up under the "image_classification.blocks." i18n namespace.
 *
 * @param {string} className - Class to label the example with.
 */
export async function createScreenshotFor(className) {
    const classNumber = selectGatherClass(className);
    if (!window.videoPlaying) {
        return;
    }
    if (recordExample(classNumber, window.VIDEO)) {
        renderExampleCounts(function (name) {
            return i18n.t("image_classification.blocks." + name);
        });
    }
}

/**
 * Classify `pixelSource` with MobileNet followed by the trained head, store
 * the model's output scores as a plain array in window.PREDICTION_ARRAY, then
 * fire 'refreshPredictions' on window.
 *
 * @param {*} pixelSource - Anything tf.browser.fromPixels accepts (image, video, canvas...).
 */
function predictFrom(pixelSource) {
    const predictionArray = tf.tidy(function () {
        const pixels = tf.browser.fromPixels(pixelSource).div(255);
        const resized = tf.image.resizeBilinear(pixels, [window.MOBILE_NET_INPUT_HEIGHT,
            window.MOBILE_NET_INPUT_WIDTH], true);
        const imageFeatures = window.mobilenet.predict(resized.expandDims());
        return window.model.predict(imageFeatures).squeeze().arraySync();
    });

    window.PREDICTION_ARRAY = predictionArray;
    // Dispatch outside tf.tidy so listener code does not run inside this
    // tidy scope (tensors it creates would otherwise be disposed here).
    window.dispatchEvent(new Event('refreshPredictions'));
}

/**
 * Predict on the current webcam frame (window.VIDEO). Does nothing unless
 * `window.predict` is true.
 */
export function webcamPredict() {
    if (window.predict) {
        predictFrom(window.VIDEO);
    }
}

/**
 * Predict on the element with id "testImage". Does nothing unless
 * `window.predict` is true and that element exists.
 */
export function imagePredict() {
    if (!window.predict) {
        return;
    }
    const img = document.getElementById('testImage');
    if (img === null) {
        return;
    }
    predictFrom(img);
}
