/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;

/**
 * Consistency checker for compiled runtime plans.
 *
 * <p>Walks a {@link ProgramBlock} tree in lockstep with the {@link StatementBlock} tree it was
 * compiled from and throws a {@link DMLRuntimeException} on the first inconsistency found:
 * <ul>
 *   <li>a program block and its statement body have a different number of child blocks;</li>
 *   <li>a while/if/for program block does not link back to its statement block;</li>
 *   <li>a function call instruction has no matching {@link FunctionOp} in the HOP DAG(s) it was
 *       compiled from, or calls a function that does not exist.</li>
 * </ul>
 * Called functions are checked recursively; recursive call chains are cut via {@code fnStack}.
 */
public class OptTreePlanChecker {

    /**
     * Recursively checks the program block {@code pb} against the statement block {@code sb}.
     *
     * @param pb      program block to check
     * @param sb      statement block {@code pb} was compiled from
     * @param fnStack keys of the functions currently being checked; a function whose key is
     *                already present is not descended into again. Keys added during this call
     *                are removed again before it returns (also on failure).
     * @throws DMLRuntimeException if an inconsistency between the two trees is detected
     */
    public static void checkProgramCorrectness(ProgramBlock pb, StatementBlock sb, Set<String> fnStack) {
        Program prog = pb.getProgram();
        DMLProgram dprog = sb.getDMLProg();
        if (pb instanceof FunctionProgramBlock && sb instanceof FunctionStatementBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock) pb;
            FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            checkChildBlocks(fpb, fpb.getChildBlocks(), fstmt.getBody(), fnStack);
        } else if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            checkHopDagCorrectness(prog, dprog, wsb.getPredicateHops(), wpb.getPredicate(), fnStack);
            checkChildBlocks(wpb, wpb.getChildBlocks(), wstmt.getBody(), fnStack);
            checkLinksProgramStatementBlock(wpb, wsb);
        } else if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            checkHopDagCorrectness(prog, dprog, isb.getPredicateHops(), ipb.getPredicate(), fnStack);
            checkChildBlocks(ipb, ipb.getChildBlocksIfBody(), istmt.getIfBody(), fnStack);
            checkChildBlocks(ipb, ipb.getChildBlocksElseBody(), istmt.getElseBody(), fnStack);
            checkLinksProgramStatementBlock(ipb, isb);
        } else if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            // from/to/increment are each compiled from their own HOP DAG;
            // any of them may be absent (null hops), which the Hop overload tolerates
            checkHopDagCorrectness(prog, dprog, fsb.getFromHops(), fpb.getFromInstructions(), fnStack);
            checkHopDagCorrectness(prog, dprog, fsb.getToHops(), fpb.getToInstructions(), fnStack);
            checkHopDagCorrectness(prog, dprog, fsb.getIncrementHops(), fpb.getIncrementInstructions(), fnStack);
            checkChildBlocks(fpb, fpb.getChildBlocks(), fstmt.getBody(), fnStack);
            checkLinksProgramStatementBlock(fpb, fsb);
        } else if (pb instanceof BasicProgramBlock) {
            BasicProgramBlock bpb = (BasicProgramBlock) pb;
            checkHopDagCorrectness(prog, dprog, sb.getHops(), bpb.getInstructions(), fnStack);
        }
    }

    /**
     * Checks that the child program blocks and child statement blocks correspond one-to-one
     * (same count, positional pairing) and recursively checks each pair. A {@code null} list is
     * treated as empty.
     *
     * @throws DMLRuntimeException if the two lists have different sizes
     */
    private static void checkChildBlocks(ProgramBlock parent, List<ProgramBlock> pbs,
            List<StatementBlock> sbs, Set<String> fnStack) {
        int npb = (pbs != null) ? pbs.size() : 0;
        int nsb = (sbs != null) ? sbs.size() : 0;
        if (npb != nsb) {
            throw new DMLRuntimeException("Number of child program blocks (" + npb
                + ") and child statement blocks (" + nsb + ") differ (" + parent + ").");
        }
        for (int i = 0; i < npb; ++i) {
            checkProgramCorrectness(pbs.get(i), sbs.get(i), fnStack);
        }
    }

    /**
     * Checks the instructions {@code inst} compiled from the HOP DAG with the given roots.
     * Does nothing if there are no roots.
     */
    private static void checkHopDagCorrectness(Program prog, DMLProgram dprog, List<Hop> roots,
            List<Instruction> inst, Set<String> fnStack) {
        if (roots == null || roots.isEmpty()) {
            return;
        }
        checkFunctionNames(prog, dprog, roots, inst, fnStack);
    }

    /**
     * Checks the instructions {@code inst} compiled from the single-root HOP DAG {@code root}
     * (e.g. a loop or branch predicate). Does nothing if {@code root} is {@code null}.
     */
    private static void checkHopDagCorrectness(Program prog, DMLProgram dprog, Hop root,
            List<Instruction> inst, Set<String> fnStack) {
        if (root == null) {
            return;
        }
        checkFunctionNames(prog, dprog, Collections.singletonList(root), inst, fnStack);
    }

    /**
     * Verifies that the program block is linked back to the statement block it was compiled from.
     *
     * @throws DMLRuntimeException if {@code pb.getStatementBlock()} is not {@code sb}
     */
    private static void checkLinksProgramStatementBlock(ProgramBlock pb, StatementBlock sb) {
        if (pb.getStatementBlock() != sb) {
            throw new DMLRuntimeException("Links between programblocks and statementblocks are incorrect (" + pb + ").");
        }
    }

    /**
     * Cross-checks every function call instruction in {@code inst} against the function ops of
     * the whole HOP DAG given by {@code roots}, verifies that the called function exists, and
     * recursively checks the called function unless it is already on {@code fnStack}.
     *
     * <p>Function ops are collected over all roots first so that an instruction is matched
     * against the complete DAG of its block rather than against one root at a time.
     *
     * @throws DMLRuntimeException if an instruction has no matching function op, or the called
     *                             function has no program block or statement block
     */
    private static void checkFunctionNames(Program prog, DMLProgram dprog, List<Hop> roots,
            List<Instruction> inst, Set<String> fnStack) {
        if (inst == null || inst.isEmpty()) {
            return; // no compiled instructions, nothing to cross-check
        }
        Map<String, FunctionOp> fops = collectFunctionOps(roots);
        for (Instruction linst : inst) {
            if (!(linst instanceof FunctionCallCPInstruction)) {
                continue;
            }
            FunctionCallCPInstruction flinst = (FunctionCallCPInstruction) linst;
            String fnamespace = flinst.getNamespace();
            String fname = flinst.getFunctionName();
            String key = DMLProgram.constructFunctionKey(fnamespace, fname);
            if (!fops.containsKey(key)) {
                throw new DMLRuntimeException("Function Check: instruction and hop names differ (" + key + ", " + fops.keySet() + ")");
            }
            if (!prog.getFunctionProgramBlocks().containsKey(key)) {
                throw new DMLRuntimeException("Function Check: function does not exist (" + key + ")");
            }
            // Set.add returns false if the key is already present, i.e. this function is
            // currently being checked further up the call chain: skip to avoid infinite recursion
            if (!fnStack.add(key)) {
                continue;
            }
            try {
                FunctionProgramBlock fpb = prog.getFunctionProgramBlock(fnamespace, fname);
                FunctionStatementBlock fsb = dprog.getFunctionStatementBlock(fnamespace, fname);
                if (fsb == null) {
                    throw new DMLRuntimeException("Function Check: function statement block does not exist (" + key + ")");
                }
                checkProgramCorrectness(fpb, fsb, fnStack);
            } finally {
                // restore the caller's stack even if the nested check fails
                fnStack.remove(key);
            }
        }
    }

    /**
     * Collects all {@link FunctionOp}s reachable from the given roots, keyed by function key.
     * Visit status is reset on all roots before any traversal so that shared sub-DAGs are
     * visited exactly once and stale flags from earlier passes cannot hide nodes.
     */
    private static Map<String, FunctionOp> collectFunctionOps(List<Hop> roots) {
        for (Hop root : roots) {
            root.resetVisitStatus();
        }
        HashMap<String, FunctionOp> fops = new HashMap<>();
        for (Hop root : roots) {
            getAllFunctionOps(root, fops);
        }
        return fops;
    }

    /**
     * Depth-first traversal that records every not-yet-visited {@link FunctionOp} reachable
     * from {@code hop} into {@code memo} (keyed by {@link FunctionOp#getFunctionKey()}), marking
     * each traversed hop as visited.
     */
    private static void getAllFunctionOps(Hop hop, Map<String, FunctionOp> memo) {
        if (hop.isVisited()) {
            return;
        }
        if (hop instanceof FunctionOp) {
            FunctionOp fop = (FunctionOp) hop;
            memo.put(fop.getFunctionKey(), fop);
        }
        for (Hop in : hop.getInput()) {
            getAllFunctionOps(in, memo);
        }
        hop.setVisited();
    }
}

