import {
    getActionsFromPosition,
    getCurrentActions,
    getCurrentValueFromAction,
    getNextActionValues,
    getValueFromPositionAndAction,
    isQTableReady,
    updateCurrentAction,
    updateQTable,
    getCellType
} from "./GridWorld";
import {
    BAD_FIELD_REWARD,
    DISCOUNT_FACTOR,
    EXPLORATION_RATE, G, B, getInitialAgent,
    GOAL_REWARD,
    LEARNING_RATE,
    VISIT_REWARD
} from "../constants";
import {getCurrentAgentPosition, setNextAgentPosition, updateAgent} from "./Agent";

// Module-level learning state shared by the functions in this file.

// Action most recently chosen by calculateBestAction(); returned by
// getBestAction() and read by calculateTDReward() and updateQValue().
let bestAction;
// Reward for the last move, assigned by moveAgent(); read by
// calculateTDReward() and compared against GOAL_REWARD in updateQValue().
let reward;
// Temporal-difference error computed by calculateTDReward() and applied to
// the Q-value in updateQValue().
let temporalDifferenceReward;

// Hyper-parameters: initialised from ../constants, overridable via setParameter().
let learningRate = LEARNING_RATE;
let discountFactor = DISCOUNT_FACTOR;
let explorationRate = EXPLORATION_RATE;

/**
 * Overrides the hyper-parameters used by the Q-learning functions.
 *
 * Any argument passed as `undefined` (or omitted) keeps its current value;
 * every other value, including `null` and `0`, is assigned as-is.
 *
 * @param {number} [customLearningRate] - step size applied to the TD error in updateQValue().
 * @param {number} [customDiscountFactor] - weight of the best next-action value in calculateTDReward().
 * @param {number} [customExplorationRate] - probability of a random action in calculateBestAction().
 */
export function setParameter(customLearningRate, customDiscountFactor, customExplorationRate){
    const keepUnlessGiven = (candidate, current) => (candidate === undefined ? current : candidate);

    learningRate = keepUnlessGiven(customLearningRate, learningRate);
    discountFactor = keepUnlessGiven(customDiscountFactor, discountFactor);
    explorationRate = keepUnlessGiven(customExplorationRate, explorationRate);
}

/**
 * Returns the greedy (highest-valued) action for the field at (x, y).
 *
 * While the Q-table is not ready yet, the first available action is returned
 * instead. Ties keep the earliest action in the order returned by
 * getActionsFromPosition().
 *
 * @param {number} x - horizontal coordinate (a "right" move increases it).
 * @param {number} y - vertical coordinate (an "up" move decreases it).
 * @returns {string|undefined} the chosen action, or undefined when the field
 *   has no actions.
 */
export function getCurrentPolicyFromField(x,y) {
    const fieldActions = getActionsFromPosition(x, y);

    // Readiness cannot change while we scan the actions, so check it once.
    if(!isQTableReady()){
        return fieldActions[0];
    }

    // Local names deliberately differ from the module-level `bestAction`.
    let greedyAction;
    let greedyValue;
    for(const action of fieldActions){
        const value = getValueFromPositionAndAction(x, y, action);
        if(greedyValue === undefined || value > greedyValue){
            greedyAction = action;
            greedyValue = value;
        }
    }
    return greedyAction;
}

/**
 * Returns the action with the highest Q-value among the agent's current
 * actions; ties keep the earliest action in getCurrentActions() order.
 * Used for the greedy (exploitation) branch of calculateBestAction().
 *
 * @returns {string|undefined} the best action, or undefined when there are
 *   no current actions.
 */
function getMaxPossibleAction(){
    // Local names deliberately differ from the module-level `bestAction`.
    let topAction;
    let topValue;
    for(const action of getCurrentActions()){
        // Fetch each Q-value once instead of once for the comparison and once for the store.
        const value = getCurrentValueFromAction(action);
        if(topValue === undefined || value > topValue){
            topAction = action;
            topValue = value;
        }
    }
    return topAction;
}

/**
 * Picks the next action with an epsilon-greedy rule and stores it in the
 * module-level `bestAction`.
 *
 * With probability `explorationRate` a uniformly random current action is
 * chosen; otherwise the highest-valued action from getMaxPossibleAction().
 *
 * @returns {string|undefined} the chosen action (also stored in `bestAction`).
 */
export function calculateBestAction() {
    const available = getCurrentActions();
    const shouldExplore = Math.random() < explorationRate;

    bestAction = shouldExplore
        ? available[Math.floor(Math.random() * available.length)]
        : getMaxPossibleAction();

    return bestAction;
}
/**
 * Reads the action most recently stored by calculateBestAction().
 *
 * @returns {string|undefined} the stored action, or undefined if none has
 *   been chosen yet.
 */
function getBestAction() {
    const storedAction = bestAction;
    return storedAction;
}

/**
 * Computes the temporal-difference error for the last transition and stores
 * it in the module-level `temporalDifferenceReward`:
 *
 *   reward + discountFactor * max(getNextActionValues())
 *          - getCurrentValueFromAction(bestAction)
 *
 * If getNextActionValues() returns an empty list, the future term is treated
 * as 0. Without this, Math.max() of no arguments yields -Infinity, which
 * would make the TD error (and then the Q-value) -Infinity.
 */
export function calculateTDReward(){
    const nextValues = getNextActionValues();
    const maxNextValue = nextValues.length > 0 ? Math.max(...nextValues) : 0;
    temporalDifferenceReward = reward + discountFactor * maxNextValue - getCurrentValueFromAction(bestAction);
}

/**
 * Applies the stored TD error to the Q-value of `bestAction`. After a move
 * that earned GOAL_REWARD, it also schedules the agent back to a start
 * position. Finally it refreshes the Q-table and the agent.
 *
 * Q(bestAction) <- Q(bestAction) + learningRate * temporalDifferenceReward
 *
 * @param {number} [episode=0] - passed through to updateAgent().
 * @param {*} [level] - optional level identifier. When provided, the reset
 *   position after reaching the goal comes from getInitialAgent(level), the
 *   same start checkPolicy() uses. When omitted, the default (0, 4) is kept.
 */
export function updateQValue(episode = 0, level){
    const currentQValue = getCurrentValueFromAction(bestAction);
    updateCurrentAction(bestAction, currentQValue + learningRate * temporalDifferenceReward);

    if(reward === GOAL_REWARD){
        // (0, 4) is the default start used when the caller does not pass a level.
        const start = level === undefined ? { x: 0, y: 4 } : getInitialAgent(level);
        setNextAgentPosition(start.x, start.y);
    }

    updateQTable();
    updateAgent(episode);
}

/**
 * Moves the agent one field in the direction of the current best action.
 * It records the reward for the destination field and stores the new
 * position via setNextAgentPosition().
 *
 * Destination reward: GOAL_REWARD for a goal field (G), BAD_FIELD_REWARD for
 * a bad field (B), VISIT_REWARD otherwise. "up" decreases y; "right"
 * increases x. No bounds checking happens here. This presumably relies on
 * the available actions never leading off the grid (TODO confirm).
 *
 * @throws {Error} if the current best action is not "up", "right", "down"
 *   or "left" (e.g. calculateBestAction() has not been called yet). The
 *   agent position is not touched in that case.
 */
export function moveAgent(){
    const { x, y } = getCurrentAgentPosition();
    const action = getBestAction();

    let dx = 0;
    let dy = 0;
    switch (action) {
        case "up": dy = -1; break;
        case "right": dx = 1; break;
        case "down": dy = 1; break;
        case "left": dx = -1; break;
        default:
            // Previously this fell through to setNextAgentPosition(undefined, undefined).
            throw new Error(`moveAgent: unknown action "${action}"`);
    }
    const nextAgentX = x + dx;
    const nextAgentY = y + dy;

    const cellType = getCellType(nextAgentX, nextAgentY);
    if(cellType === G){
        reward = GOAL_REWARD;
    } else if(cellType === B){
        reward = BAD_FIELD_REWARD;
    } else {
        reward = VISIT_REWARD;
    }

    setNextAgentPosition(nextAgentX, nextAgentY);
}

/**
 * Follows the current greedy policy from the level's initial agent position.
 * Reports whether the policy enters a goal field (G) within `maxSteps` moves.
 *
 * Positions are simulated with local variables; this function does not call
 * setNextAgentPosition(). A policy result other than "up"/"right"/"down"/"left"
 * (e.g. undefined) leaves the simulated position unchanged for that step.
 *
 * @param {*} level - level identifier passed to getInitialAgent().
 * @param {number} [maxSteps=10] - maximum number of moves to simulate.
 * @returns {boolean} true if a goal field is entered within `maxSteps` moves.
 */
export function checkPolicy(level, maxSteps = 10) {
    const start = getInitialAgent(level);
    let x = start.x;
    let y = start.y;

    for (let step = 0; step < maxSteps; step++){
        switch (getCurrentPolicyFromField(x, y)) {
            case "up": y -= 1; break;
            case "right": x += 1; break;
            case "down": y += 1; break;
            case "left": x -= 1; break;
            default: break; // missing/unknown action: stay in place
        }

        if(getCellType(x, y) === G){
            return true;
        }
    }
    return false;
}
