/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.MultiReturnParameterizedBuiltinFEDInstruction;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;

public class ParameterizedBuiltinFEDInstruction
extends ComputationFEDInstruction {
    protected final LinkedHashMap<String, String> params;

    protected ParameterizedBuiltinFEDInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
        this.params = paramsMap;
    }

    public HashMap<String, String> getParameterMap() {
        return this.params;
    }

    public String getParam(String key) {
        return this.getParameterMap().get(key);
    }

    public static LinkedHashMap<String, String> constructParameterMap(String[] params) {
        LinkedHashMap<String, String> paramMap = new LinkedHashMap<String, String>();
        for (int i = 1; i <= params.length - 2; ++i) {
            String[] parts = params[i].split("=");
            paramMap.put(parts[0], parts[1]);
        }
        return paramMap;
    }

    public static ParameterizedBuiltinFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand out = new CPOperand(parts[parts.length - 1]);
        LinkedHashMap<String, String> paramsMap = ParameterizedBuiltinFEDInstruction.constructParameterMap(parts);
        if (opcode.equalsIgnoreCase("replace")) {
            ParameterizedBuiltin func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
            return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
        }
        if (opcode.equals("transformapply") || opcode.equals("transformdecode")) {
            return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
        }
        throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        String opcode = this.getOpcode();
        if (opcode.equalsIgnoreCase("replace")) {
            MatrixObject mo = (MatrixObject)this.getTarget(ec);
            FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
            mo.getFedMapping().execute(this.getTID(), true, fr1);
            MatrixObject out = ec.getMatrixObject(this.output);
            out.getDataCharacteristics().set(mo.getDataCharacteristics());
            out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
        } else if (opcode.equalsIgnoreCase("transformdecode")) {
            this.transformDecode(ec);
        } else if (opcode.equalsIgnoreCase("transformapply")) {
            this.transformApply(ec);
        } else {
            throw new DMLRuntimeException("Unknown opcode : " + opcode);
        }
    }

    private void transformDecode(ExecutionContext ec) {
        MatrixObject mo = ec.getMatrixObject(this.params.get("target"));
        FrameBlock meta = ec.getFrameInput(this.params.get("meta"));
        String spec = this.params.get("spec");
        Decoder globalDecoder = DecoderFactory.createDecoder(spec, meta.getColumnNames(), null, meta, (int)mo.getNumColumns());
        FederationMap fedMapping = mo.getFedMapping();
        Types.ValueType[] schema = new Types.ValueType[(int)mo.getNumColumns()];
        long varID = FederationUtils.getNextFedDataID();
        FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> {
            long[] beginDims = range.getBeginDims();
            long[] endDims = range.getEndDims();
            int colStartBefore = (int)beginDims[1];
            globalDecoder.updateIndexRanges(beginDims, endDims);
            Decoder decoder = globalDecoder.subRangeDecoder((int)beginDims[1] + 1, (int)endDims[1] + 1, colStartBefore);
            FrameBlock metaSlice = new FrameBlock();
            FrameBlock frameBlock = meta;
            synchronized (frameBlock) {
                meta.slice(0, meta.getNumRows() - 1, (int)beginDims[1], (int)endDims[1] - 1, metaSlice);
            }
            try {
                FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                Types.ValueType[] subSchema = (Types.ValueType[])response.getData()[0];
                Types.ValueType[] valueTypeArray = schema;
                synchronized (schema) {
                    System.arraycopy(subSchema, 0, schema, colStartBefore, subSchema.length);
                    // ** MonitorExit[var14_15] (shouldn't be in output)
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            {
                return null;
            }
        });
        FrameObject decodedFrame = ec.getFrameObject(this.output);
        decodedFrame.setSchema(globalDecoder.getSchema());
        decodedFrame.getDataCharacteristics().set(mo.getDataCharacteristics());
        decodedFrame.getDataCharacteristics().setCols(globalDecoder.getSchema().length);
        decodedFrame.setFedMapping(decodedMapping);
        ec.releaseFrameInput(this.params.get("meta"));
    }

    private void transformApply(ExecutionContext ec) {
        FrameObject fo = ec.getFrameObject(this.params.get("target"));
        FrameBlock meta = ec.getFrameInput(this.params.get("meta"));
        String spec = this.params.get("spec");
        FederationMap fedMapping = fo.getFedMapping();
        Object[] colNames = new String[(int)fo.getNumColumns()];
        Arrays.fill(colNames, "");
        fedMapping.forEachParallel((arg_0, arg_1) -> ParameterizedBuiltinFEDInstruction.lambda$transformApply$1((String[])colNames, arg_0, arg_1));
        Encoder globalEncoder = EncoderFactory.createEncoder(spec, (String[])colNames, colNames.length, meta);
        List<Encoder> encoders = ((EncoderComposite)globalEncoder).getEncoders();
        int omitIx = -1;
        for (int i = 0; i < encoders.size(); ++i) {
            if (!(encoders.get(i) instanceof EncoderOmit)) continue;
            omitIx = i;
            break;
        }
        if (omitIx != -1) {
            ParameterizedBuiltinFEDInstruction.buildOmitEncoder(fedMapping, encoders, omitIx);
        }
        MultiReturnParameterizedBuiltinFEDInstruction.encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(this.getOutputVariableName()));
        ec.releaseFrameInput(this.params.get("meta"));
    }

    private static void buildOmitEncoder(FederationMap fedMapping, List<Encoder> encoders, int omitIx) {
        Encoder omitEncoder = encoders.get(omitIx);
        EncoderOmit newOmit = new EncoderOmit(true);
        fedMapping.forEachParallel((range, data) -> {
            try {
                EncoderOmit subRangeEncoder = (EncoderOmit)omitEncoder.subRangeEncoder(range.asIndexRange().add(1));
                FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get();
                Encoder builtEncoder = (Encoder)response.getData()[0];
                newOmit.mergeAt(builtEncoder, (int)(range.getBeginDims()[0] + 1L), (int)(range.getBeginDims()[1] + 1L));
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        });
        encoders.remove(omitIx);
        encoders.add(omitIx, newOmit);
    }

    public CacheableData<?> getTarget(ExecutionContext ec) {
        return ec.getCacheableData(this.params.get("target"));
    }

    private CPOperand getTargetOperand() {
        return new CPOperand(this.params.get("target"), Types.ValueType.FP64, Types.DataType.MATRIX);
    }

    private static /* synthetic */ Void lambda$transformApply$1(String[] colNames, FederatedRange range, FederatedData data) {
        try {
            FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new GetColumnNames(data.getVarID()))).get();
            String[] subRangeColNames = (String[])response.getData()[0];
            System.arraycopy(subRangeColNames, 0, colNames, (int)range.getBeginDims()[1], subRangeColNames.length);
        }
        catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
        return null;
    }

    private static class InitRowsToRemoveOmit
    extends FederatedUDF {
        private static final long serialVersionUID = -8196730717390438411L;
        EncoderOmit _encoder;

        public InitRowsToRemoveOmit(long varID, EncoderOmit encoder) {
            super(new long[]{varID});
            this._encoder = encoder;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameBlock fb = (FrameBlock)((FrameObject)data[0]).acquireReadAndRelease();
            this._encoder.build(fb);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{this._encoder});
        }
    }

    private static class GetColumnNames
    extends FederatedUDF {
        private static final long serialVersionUID = -7831469841164270004L;

        public GetColumnNames(long varID) {
            super(new long[]{varID});
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameBlock fb = (FrameBlock)((FrameObject)data[0]).acquireReadAndRelease();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{fb.getColumnNames()});
        }
    }

    public static class DecodeMatrix
    extends FederatedUDF {
        private static final long serialVersionUID = 2376756757742169692L;
        private final long _outputID;
        private final FrameBlock _meta;
        private final Decoder _decoder;

        public DecodeMatrix(long input, long outputID, FrameBlock meta, Decoder decoder) {
            super(new long[]{input});
            this._outputID = outputID;
            this._meta = meta;
            this._decoder = decoder;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject mo = (MatrixObject)data[0];
            MatrixBlock mb = (MatrixBlock)mo.acquireRead();
            String[] colNames = this._meta.getColumnNames();
            FrameBlock fbout = this._decoder.decode(mb, new FrameBlock(this._decoder.getSchema()));
            fbout.setColumnNames(Arrays.copyOfRange(colNames, 0, fbout.getNumColumns()));
            MatrixCharacteristics mc = new MatrixCharacteristics(mo.getDataCharacteristics());
            FrameObject fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(mc, Types.FileFormat.BINARY));
            fo.acquireModify(fbout);
            fo.release();
            mo.release();
            ec.setVariable(String.valueOf(this._outputID), fo);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{fo.getSchema()});
        }
    }
}

