/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Stream;
import java.util.zip.Adler32;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuantilePickFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.util.IndexRange;

public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFEDInstruction {
    protected final List<CPOperand> _outputs;

    /**
     * Creates a federated multi-return parameterized builtin instruction.
     *
     * @param op operator forwarded to the parent instruction (may be {@code null})
     * @param input1 first input operand; read as a frame when the instruction runs
     * @param input2 second input operand; read as a scalar string (the spec) when the instruction runs
     * @param outputs output operands, later retrievable via {@link #getOutput(int)}
     * @param opcode instruction opcode
     * @param istr full instruction string
     */
    private MultiReturnParameterizedBuiltinFEDInstruction(
        Operator op,
        CPOperand input1,
        CPOperand input2,
        List<CPOperand> outputs,
        String opcode,
        String istr) {
        // The fifth operand slot of the parent constructor is left null; the
        // outputs of this instruction are tracked in _outputs instead.
        super(FEDInstruction.FEDType.MultiReturnParameterizedBuiltin, op, input1, input2, null, opcode, istr);
        _outputs = outputs;
    }

    /**
     * Returns one of the output operands of this instruction.
     *
     * @param i zero-based position in the output operand list
     * @return the output operand stored at position {@code i}
     */
    public CPOperand getOutput(int i) {
        final CPOperand operand = _outputs.get(i);
        return operand;
    }

    /**
     * Converts a CP instruction into its federated counterpart when applicable.
     *
     * <p>The conversion happens only if the opcode is {@code transformencode},
     * the first input is a frame, and the cached data behind that input reports
     * {@code isFederatedExcept(BROADCAST)}.
     *
     * @param inst the CP instruction to inspect
     * @param ec execution context used to look up the first input
     * @return the federated instruction, or {@code null} if the conditions are not met
     */
    public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(MultiReturnParameterizedBuiltinCPInstruction inst, ExecutionContext ec) {
        // Cheap checks first; the cache lookup only happens for frame inputs of transformencode.
        if (!inst.getOpcode().equals("transformencode") || !inst.input1.isFrame()) {
            return null;
        }
        CacheableData<?> input = ec.getCacheableData(inst.input1);
        if (!input.isFederatedExcept(FTypes.FType.BROADCAST)) {
            return null;
        }
        return MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
    }

    /**
     * Converts a Spark instruction into its federated counterpart when applicable.
     *
     * <p>The conversion happens only if the opcode is {@code transformencode},
     * the first input is a frame, and the cached data behind that input reports
     * {@code isFederatedExcept(BROADCAST)}.
     *
     * @param inst the Spark instruction to inspect
     * @param ec execution context used to look up the first input
     * @return the federated instruction, or {@code null} if the conditions are not met
     */
    public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(MultiReturnParameterizedBuiltinSPInstruction inst, ExecutionContext ec) {
        // Cheap checks first; the cache lookup only happens for frame inputs of transformencode.
        if (!inst.getOpcode().equals("transformencode") || !inst.input1.isFrame()) {
            return null;
        }
        CacheableData<?> input = ec.getCacheableData(inst.input1);
        if (!input.isFederatedExcept(FTypes.FType.BROADCAST)) {
            return null;
        }
        return MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
    }

    /**
     * Builds the federated instruction from the fields of a CP instruction.
     *
     * @param instr source CP instruction
     * @return a federated instruction with the same operator, inputs, outputs, opcode and instruction string
     */
    private static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(MultiReturnParameterizedBuiltinCPInstruction instr) {
        return new MultiReturnParameterizedBuiltinFEDInstruction(
            instr.getOperator(),
            instr.input1,
            instr.input2,
            instr.getOutputs(),
            instr.getOpcode(),
            instr.getInstructionString());
    }

    /**
     * Builds the federated instruction from the fields of a Spark instruction.
     *
     * @param instr source Spark instruction
     * @return a federated instruction with the same operator, inputs, outputs, opcode and instruction string
     */
    private static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(MultiReturnParameterizedBuiltinSPInstruction instr) {
        return new MultiReturnParameterizedBuiltinFEDInstruction(
            instr.getOperator(),
            instr.input1,
            instr.input2,
            instr.getOutputs(),
            instr.getOpcode(),
            instr.getInstructionString());
    }

    /**
     * Parses a federated instruction from its string form.
     *
     * <p>Only the {@code transformencode} opcode (compared case-insensitively)
     * is accepted. Parts 1 and 2 become the input operands; parts 3 and 4 become
     * an FP64 matrix output and a STRING frame output, in that order.
     *
     * @param str instruction string
     * @return the parsed instruction (with a {@code null} operator)
     * @throws DMLRuntimeException if the opcode is not {@code transformencode}
     */
    public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(String str) {
        final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
        final String opcode = tokens[0];

        // Reject anything but transformencode up front.
        if (!opcode.equalsIgnoreCase("transformencode")) {
            throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
        }

        final CPOperand in1 = new CPOperand(tokens[1]);
        final CPOperand in2 = new CPOperand(tokens[2]);

        final List<CPOperand> outputs = new ArrayList<>();
        outputs.add(new CPOperand(tokens[3], Types.ValueType.FP64, Types.DataType.MATRIX));
        outputs.add(new CPOperand(tokens[4], Types.ValueType.STRING, Types.DataType.FRAME));

        return new MultiReturnParameterizedBuiltinFEDInstruction(null, in1, in2, outputs, opcode, str);
    }

    /**
     * Executes the federated {@code transformencode}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Build a global encoder and the column names, either via
     *       {@link #createGlobalEncoderWithEquiHeight} (when the input is not
     *       ROW-federated and the spec text mentions {@code equi-height}) or by
     *       merging the encoders built by each federated range.</li>
     *   <li>Optionally sort recode maps when {@code ColumnEncoderRecode.SORT_RECODE_MAP} is set.</li>
     *   <li>Collect the encoder metadata into a STRING frame and re-initialize the
     *       encoder from it.</li>
     *   <li>Apply the encoder on every federated range and publish the metadata
     *       frame as the second output.</li>
     * </ol>
     *
     * @param ec execution context providing inputs and receiving outputs
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        FrameObject fin = ec.getFrameObject(this.input1.getName());
        String spec = ec.getScalarInput(this.input2).getStringValue();
        FederationMap fedMapping = fin.getFedMapping();

        // Locale.ROOT keeps the keyword detection independent of the JVM default
        // locale (e.g. Turkish lower-casing maps 'I' to a dotless 'ı').
        boolean useGlobalEquiHeight = !fin.isFederated(FTypes.FType.ROW)
            && spec.toLowerCase(Locale.ROOT).contains("equi-height");

        MultiColumnEncoder globalEncoder;
        String[] colNames;
        if (useGlobalEquiHeight) {
            EncoderColnames ret = this.createGlobalEncoderWithEquiHeight(ec, fin, spec);
            globalEncoder = ret._encoder;
            colNames = ret._colnames;
        } else {
            // Encoder into which the per-range encoders are merged.
            MultiColumnEncoder mergedEncoder = new MultiColumnEncoder(new ArrayList<ColumnEncoderComposite>());
            String[] mergedColNames = new String[(int)fin.getNumColumns()];
            Arrays.fill(mergedColNames, "");
            fedMapping.forEachParallel((range, data) -> MultiReturnParameterizedBuiltinFEDInstruction.lambda$processInstruction$0(spec, mergedEncoder, mergedColNames, range, data));
            globalEncoder = mergedEncoder;
            colNames = mergedColNames;
        }

        if (ColumnEncoderRecode.SORT_RECODE_MAP) {
            globalEncoder.applyToAll(ColumnEncoderRecode.class, ColumnEncoderRecode::sortCPRecodeMaps);
        }

        // Round-trip the metadata through a frame: export it, then re-initialize from it.
        FrameBlock meta = new FrameBlock((int)fin.getNumColumns(), Types.ValueType.STRING);
        meta.setColumnNames(colNames);
        globalEncoder.getMetaData(meta);
        globalEncoder.initMetaData(meta);

        MultiReturnParameterizedBuiltinFEDInstruction.encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(this.getOutput(0)));
        ec.setFrameOutput(this.getOutput(1).getName(), meta);
    }

    /**
     * Builds the global encoder for specs that contain equi-height binning.
     *
     * <p>Each federated range builds a local encoder, which is merged into the
     * global one. For every equi-height bin encoder found, the quantile
     * positions {@code numRows / numBins * (i + 1)} are recorded per column.
     * Afterwards the bin boundaries are obtained column by column through
     * {@code QuantilePickFEDInstruction.getEquiHeightBins} and the equi-height
     * bin encoders of the global encoder are built from those boundaries.
     *
     * @param ec execution context passed on to the quantile computation
     * @param fin federated input frame
     * @param spec transform specification
     * @return the merged global encoder together with the collected column names
     * @throws DMLRuntimeException if creating an encoder on a federated range fails
     */
    public EncoderColnames createGlobalEncoderWithEquiHeight(ExecutionContext ec, FrameObject fin, String spec) {
        // Encoder in which the encoding information of all ranges is aggregated.
        MultiColumnEncoder globalEncoder = new MultiColumnEncoder(new ArrayList<ColumnEncoderComposite>());
        String[] colNames = new String[(int)fin.getNumColumns()];
        Map<Integer, double[]> quantilesPerColumn = new HashMap<>();
        FederationMap fedMapping = fin.getFedMapping();

        fedMapping.forEachParallel((range, data) -> {
            int columnOffset = (int)range.getBeginDims()[1];
            Future<FederatedResponse> responseFuture = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new CreateFrameEncoder(data.getVarID(), spec, columnOffset + 1)));
            try {
                FederatedResponse response = responseFuture.get();
                MultiColumnEncoder encoder = (MultiColumnEncoder)response.getData()[0];

                for (Encoder compositeCandidate : encoder.getColumnEncoders()) {
                    if (!(compositeCandidate instanceof ColumnEncoderComposite)) continue;
                    for (Encoder binCandidate : ((ColumnEncoderComposite)compositeCandidate).getEncoders()) {
                        if (!(binCandidate instanceof ColumnEncoderBin)) continue;
                        ColumnEncoderBin bin = (ColumnEncoderBin)binCandidate;
                        if (bin.getBinMethod() != ColumnEncoderBin.BinMethod.EQUI_HEIGHT) continue;

                        double quantileRange = (double)fin.getNumRows() / (double)bin.getNumBin();
                        double[] quantiles = new double[bin.getNumBin()];
                        for (int i = 0; i < quantiles.length; ++i) {
                            quantiles[i] = quantileRange * (double)(i + 1);
                        }
                        // HashMap is not thread-safe and this callback runs in parallel,
                        // so the shared map must be guarded like the encoder merge below.
                        synchronized (quantilesPerColumn) {
                            quantilesPerColumn.put(bin.getColID() + columnOffset - 1, quantiles);
                        }
                    }
                }

                synchronized (globalEncoder) {
                    globalEncoder.mergeAt(encoder, columnOffset, (int)(range.getBeginDims()[0] + 1L));
                }
                String[] subRangeColNames = (String[])response.getData()[1];
                System.arraycopy(subRangeColNames, 0, colNames, (int)range.getBeginDims()[1], subRangeColNames.length);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated encoder creation failed: ", e);
            }
            return null;
        });

        // Compute the equi-height bin boundaries for every recorded column.
        Map<Integer, double[]> equiHeightBinsPerColumn = new HashMap<>();
        for (Map.Entry<Integer, double[]> entry : quantilesPerColumn.entrySet()) {
            QuantilePickFEDInstruction quantileInstr = new QuantilePickFEDInstruction(null, this.input1, this.output, PickByCount.OperationTypes.VALUEPICK, true, "qpick", "");
            MatrixBlock quantiles = quantileInstr.getEquiHeightBins(ec, entry.getKey(), entry.getValue());
            equiHeightBinsPerColumn.put(entry.getKey(), quantiles.getDenseBlockValues());
        }

        // Feed the boundaries into the equi-height bin encoders of the global encoder.
        for (Encoder compositeCandidate : globalEncoder.getColumnEncoders()) {
            if (!(compositeCandidate instanceof ColumnEncoderComposite)) continue;
            ColumnEncoderComposite composite = (ColumnEncoderComposite)compositeCandidate;
            for (Encoder binCandidate : composite.getEncoders()) {
                if (!(binCandidate instanceof ColumnEncoderBin)) continue;
                ColumnEncoderBin bin = (ColumnEncoderBin)binCandidate;
                if (bin.getBinMethod() != ColumnEncoderBin.BinMethod.EQUI_HEIGHT) continue;
                double[] binBoundaries = equiHeightBinsPerColumn.get(bin.getColID() - 1);
                bin.build(null, binBoundaries);
            }
            composite.updateAllDCEncoders();
        }
        return new EncoderColnames(globalEncoder, colNames);
    }

    /**
     * Applies the global encoder on every federated range and registers the
     * resulting federation map on the output matrix.
     *
     * <p>For each range a sub-range encoder is derived from the global encoder,
     * shipped to the worker as an {@link ExecuteFrameEncoder} UDF, and the
     * reported number of non-zeros is accumulated. Finally the dimensions and
     * non-zero count of {@code transformedMat} are updated from the new map.
     *
     * @param fedMapping federation map of the input frame
     * @param globalencoder encoder covering all columns of the input
     * @param transformedMat output matrix object receiving the new federation map
     * @throws DMLRuntimeException if executing the encoder on a range fails
     */
    public static void encodeFederatedFrames(FederationMap fedMapping, MultiColumnEncoder globalencoder, MatrixObject transformedMat) {
        final long outputId = FederationUtils.getNextFedDataID();
        final LongAdder totalNnz = new LongAdder();

        FederationMap encodedMap = fedMapping.mapParallel(outputId, (range, data) -> {
            final long[] begin = range.getBeginDims();
            final long[] end = range.getEndDims();

            // Range of this partition shifted by one, and the region that precedes it.
            IndexRange shiftedRange = new IndexRange(begin[0], end[0], begin[1], end[1]).add(1);
            IndexRange precedingRange = new IndexRange(0L, begin[0], 0L, begin[1]);

            MultiColumnEncoder rangeEncoder = globalencoder.subRangeEncoder(shiftedRange);
            rangeEncoder.updateIndexRanges(begin, end, globalencoder.getNumExtraCols(precedingRange));

            try {
                FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new ExecuteFrameEncoder(data.getVarID(), outputId, rangeEncoder));
                FederatedResponse response = data.executeFederatedOperation(request).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                long rangeNnz = (Long)response.getData()[0];
                totalNnz.add(rangeNnz);
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        });

        transformedMat.getDataCharacteristics()
            .setDimension(encodedMap.getMaxIndexInRange(0), encodedMap.getMaxIndexInRange(1))
            .setNonZeros(totalNnz.sum());
        transformedMat.setFedMapping(encodedMap);
    }

    /**
     * Per-range callback used by {@code processInstruction}: asks the worker to
     * build an encoder for its part of the frame, merges that encoder into the
     * shared global encoder, and copies the worker's column names into place.
     *
     * <p>Decompiler note (CFR): a try/catch surrounding this body was removed
     * during decompilation, so behaviour may differ from the original bytecode.
     *
     * @param spec transform specification
     * @param finalGlobalEncoder shared encoder receiving the merged result (guarded by its own monitor)
     * @param finalColNames shared column-name array, filled at this range's column offset
     * @param range federated range being processed
     * @param data handle of the worker that owns the range
     * @return always {@code null}
     * @throws DMLRuntimeException if fetching or merging the worker's encoder fails
     */
    private static /* synthetic */ Void lambda$processInstruction$0(String spec, MultiColumnEncoder finalGlobalEncoder, String[] finalColNames, FederatedRange range, FederatedData data) {
        final int colOffset = (int)range.getBeginDims()[1];
        final FederatedRequest request = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new CreateFrameEncoder(data.getVarID(), spec, colOffset + 1));
        final Future<FederatedResponse> pending = data.executeFederatedOperation(request);
        try {
            final FederatedResponse response = pending.get();
            final MultiColumnEncoder localEncoder = (MultiColumnEncoder)response.getData()[0];
            final int rowStart = (int)(range.getBeginDims()[0] + 1L);
            synchronized (finalGlobalEncoder) {
                finalGlobalEncoder.mergeAt(localEncoder, colOffset, rowStart);
            }
            final String[] localColNames = (String[])response.getData()[1];
            System.arraycopy(localColNames, 0, finalColNames, (int)range.getBeginDims()[1], localColNames.length);
            return null;
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Federated encoder creation failed: ", e);
        }
    }

    /**
     * Federated UDF that applies a given encoder to the frame stored under the
     * input ID and stores the encoded matrix under the output ID.
     */
    public static class ExecuteFrameEncoder
    extends FederatedUDF {
        private static final long serialVersionUID = 6034440964680578276L;
        // ID under which the encoded matrix is registered in the execution context.
        private final long _outputID;
        // Encoder that is applied to the input frame.
        private final MultiColumnEncoder _encoder;

        /**
         * @param input ID of the input frame
         * @param output ID under which the encoded matrix is stored
         * @param encoder encoder to apply
         */
        public ExecuteFrameEncoder(long input, long output, MultiColumnEncoder encoder) {
            super(new long[]{input});
            this._outputID = output;
            this._encoder = encoder;
        }

        /**
         * Encodes the input frame and registers the result as a variable.
         *
         * @param ec execution context in which the output variable is set
         * @param data UDF inputs; {@code data[0]} is expected to be the input {@link FrameObject}
         * @return an empty-success response carrying the number of non-zeros of the encoded matrix
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameBlock fb = (FrameBlock)((FrameObject)data[0]).acquireReadAndRelease();
            this._encoder.applyColumnOffset();
            MatrixBlock mbout = this._encoder.apply(fb, 1);
            MatrixObject mo = ExecutionContext.createMatrixObject(mbout);
            ec.setVariable(String.valueOf(this._outputID), mo);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY, mbout.getNonZeros());
        }

        /**
         * @return a mutable single-element list containing the output ID
         */
        @Override
        public List<Long> getOutputIds() {
            return new ArrayList<Long>(Arrays.asList(this._outputID));
        }

        /**
         * Builds the lineage item of the output: the lineage of all UDF inputs
         * plus a scalar literal holding the Adler-32 checksum of the serialized
         * encoder.
         *
         * @param ec execution context providing the lineage of the inputs
         * @return pair of output variable name and its lineage item
         */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            LineageItem[] liUdfInputs = Arrays.stream(this.getInputIDs()).mapToObj(id -> ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
            // The checksum of the serialized encoder stands in for the encoder in the lineage.
            Adler32 checksum = new Adler32();
            byte[] bytes = SerializationUtils.serialize((Serializable)this._encoder);
            checksum.update(bytes, 0, bytes.length);
            CPOperand encoder = new CPOperand(String.valueOf(checksum.getValue()), Types.ValueType.INT64, Types.DataType.SCALAR, true);
            LineageItem[] otherInputs = LineageItemUtils.getLineage(ec, encoder);
            LineageItem[] liInputs = Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs)).toArray(LineageItem[]::new);
            // No Object casts here: they would make Pair.of infer Pair<Object, Object>,
            // which is not compatible with the declared return type.
            return Pair.of(String.valueOf(this._outputID), new LineageItem(this.getClass().getSimpleName(), liInputs));
        }
    }

    /**
     * Federated UDF that creates and builds an encoder for the frame stored
     * under the input ID, according to a transform specification.
     */
    public static class CreateFrameEncoder
    extends FederatedUDF {
        private static final long serialVersionUID = 2376756757742169692L;
        // Transform specification handed to the encoder factory.
        private final String _spec;
        // Column offset handed to the encoder factory as the lower column bound.
        private final int _offset;

        /**
         * @param input ID of the input frame
         * @param spec transform specification
         * @param offset lower column bound passed to the encoder factory
         */
        public CreateFrameEncoder(long input, String spec, int offset) {
            super(new long[]{input});
            this._spec = spec;
            this._offset = offset;
        }

        /**
         * Creates the encoder for the local frame and builds it.
         *
         * @param ec execution context (not used by this UDF)
         * @param data UDF inputs; {@code data[0]} is expected to be the input {@link FrameObject}
         * @return a success response holding the built encoder and the frame's column names
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FrameObject fo = (FrameObject)data[0];
            FrameBlock fb = (FrameBlock)fo.acquireRead();
            MultiColumnEncoder encoder;
            try {
                String[] colNames = fb.getColumnNames();
                encoder = EncoderFactory.createEncoder(this._spec, colNames, fb.getNumColumns(), null, this._offset, this._offset + fb.getNumColumns());
                encoder.build(fb, 1);
            }
            finally {
                // Always give the read acquisition back, even if creating or
                // building the encoder throws.
                fo.release();
            }
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{encoder, fb.getColumnNames()});
        }

        /**
         * @return always {@code null}; this UDF does not provide a lineage item
         */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }

    private class EncoderColnames {
        public final MultiColumnEncoder _encoder;
        public final String[] _colnames;

        public EncoderColnames(MultiColumnEncoder encoder, String[] colnames) {
            this._encoder = encoder;
            this._colnames = colnames;
        }
    }
}

