/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer;
import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
import org.apache.sysds.utils.stats.ParamServStatistics;
import org.apache.sysds.utils.stats.Timing;

/**
 * Parameter server for homomorphically encrypted federated training.
 *
 * <p>Every {@link #push} is a two-phase exchange in which all workers must take part:
 * <ol>
 *   <li>each worker pushes its encrypted model; once all have arrived, the ciphertexts are
 *       summed per matrix via {@code SEALServer#accumulateCiphertexts};</li>
 *   <li>each worker computes a partial decryption of that sum through its registered
 *       {@link FederatedPSControlThread}; once all partial decryptions have arrived they are
 *       combined via {@code SEALServer#average} and the resulting model is broadcast.</li>
 * </ol>
 *
 * <p>Both phases go through {@link #collectAndDo}, a reusable barrier guarded by this
 * object's monitor.
 */
public class HEParamServer
extends LocalParamServer {
    /** Number of workers that have arrived at the current barrier round. */
    private int _thread_counter = 0;
    /**
     * Number of completed barrier rounds. Waiters loop until it changes, so a spurious
     * wakeup cannot release them before their round is done.
     */
    private long _generation = 0;
    /** Worker control threads indexed by worker id; filled in by {@link #registerThread}. */
    private final List<FederatedPSControlThread> _threads;
    /** Per-worker contributions of the current barrier round, indexed by worker id. */
    private final List<Object> _result_buffer;
    /** Result of the most recently completed barrier round. */
    private Object _result;
    /** Exception thrown by the most recently completed round's action, or {@code null}. */
    private Throwable _failure;
    private final SEALServer _seal_server = new SEALServer();
    /**
     * Started after the aggregation phase and stopped when all partial decryptions have been
     * collected; the elapsed time is recorded as federated network time.
     */
    private Timing commTimer;

    /**
     * Initializes the native HE helper and creates a new server. All arguments are forwarded
     * unchanged to the constructor.
     */
    public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, int numBackupWorkers) {
        NativeHEHelper.initialize();
        return new HEParamServer(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, numBackupWorkers);
    }

    private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, int numBackupWorkers) {
        // NOTE(review): the literal `true` is a LocalParamServer constructor flag whose meaning
        // is defined in the superclass - confirm there.
        super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true, numBackupWorkers);
        _threads = Collections.synchronizedList(new ArrayList<>(workerNum));
        for (int i = 0; i < getNumWorkers(); i++) {
            _threads.add(null); // placeholder until the worker calls registerThread
        }
        _result_buffer = new ArrayList<>(workerNum);
        resetResultBuffer();
    }

    /**
     * Registers the control thread of a worker so {@link #push} can request its partial
     * decryption.
     *
     * @param thread_id worker id, in {@code [0, getNumWorkers())}
     * @param thread    the worker's control thread
     */
    public void registerThread(int thread_id, FederatedPSControlThread thread) {
        _threads.set(thread_id, thread);
    }

    /** Clears the barrier buffer to one {@code null} slot per worker. */
    private synchronized void resetResultBuffer() {
        _result_buffer.clear();
        for (int i = 0; i < getNumWorkers(); i++) {
            _result_buffer.add(null);
        }
    }

    /** Delegates to {@code SEALServer#generateA}. */
    public byte[] generateA() {
        return _seal_server.generateA();
    }

    /** Combines the workers' partial public keys via {@code SEALServer#aggregatePartialPublicKeys}. */
    public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) {
        return _seal_server.aggregatePartialPublicKeys(partial_public_keys);
    }

    /**
     * Reusable barrier. Stores this worker's contribution, then blocks until every worker has
     * contributed to the current round. The last arriving worker applies {@code f} to the
     * collected contributions (ordered by worker id), publishes the result and wakes the
     * others. Every participant then returns that same result.
     *
     * @param workerId index of the calling worker, in {@code [0, getNumWorkers())}
     * @param obj      this worker's contribution
     * @param f        action run once per round, by the last arriving worker
     * @return the value returned by {@code f} for this round
     * @throws RuntimeException if a waiting thread is interrupted, or if {@code f} failed. The
     *         worker that ran {@code f} rethrows the original exception; the other workers
     *         throw a wrapper whose cause is that exception.
     */
    private synchronized <T, U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) {
        _result_buffer.set(workerId, obj);
        final long round = _generation;
        if (++_thread_counter == getNumWorkers()) {
            try {
                // Snapshot the contributions; the buffer itself is reset for the next round below.
                @SuppressWarnings("unchecked")
                List<T> contributions = (List<T>) (List<?>) new ArrayList<Object>(_result_buffer);
                _result = f.apply(contributions);
                _failure = null;
            }
            catch (RuntimeException | Error e) {
                _result = null;
                _failure = e;
                throw e;
            }
            finally {
                // Always release the waiters; otherwise a failing action would block them forever.
                resetResultBuffer();
                _thread_counter = 0;
                _generation++;
                notifyAll();
            }
        }
        else {
            try {
                // Loop guards against spurious wakeups: only a completed round releases us.
                while (_generation == round) {
                    wait();
                }
            }
            catch (InterruptedException e) {
                // NOTE: this worker's contribution remains counted for the current round.
                Thread.currentThread().interrupt();
                throw new RuntimeException("thread interrupted", e);
            }
            if (_failure != null) {
                throw new RuntimeException("aggregation failed on another worker", _failure);
            }
        }
        @SuppressWarnings("unchecked")
        U result = (U) _result;
        return result;
    }

    /**
     * Sums the workers' encrypted models matrix by matrix.
     *
     * @param encrypted_models one list per worker; element {@code k} of each list is the
     *                         ciphertext of matrix {@code k}
     * @return one accumulated ciphertext per matrix position
     */
    private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) {
        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
        int numMatrices = encrypted_models.get(0).getLength();
        CiphertextMatrix[] result = new CiphertextMatrix[numMatrices];
        for (int matrixIdx = 0; matrixIdx < numMatrices; matrixIdx++) {
            CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()];
            for (int i = 0; i < summands.length; i++) {
                summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrixIdx);
            }
            result[matrixIdx] = _seal_server.accumulateCiphertexts(summands);
        }
        if (tAgg != null) {
            ParamServStatistics.accHEAccumulation((long) tAgg.stop());
        }
        return result;
    }

    /**
     * Combines the workers' partial decryptions of the encrypted sums into plaintext matrices,
     * installs them as the new model and broadcasts it.
     *
     * @param encrypted_sums      per-matrix ciphertext sums from the aggregation phase
     * @param partial_decryptions one array per worker, indexed by matrix position
     * @return always {@code null}; typed {@code Void} so it can serve as a collectAndDo action
     */
    private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) {
        Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
        int numMatrices = partial_decryptions.get(0).length;
        MatrixObject[] result = new MatrixObject[numMatrices];
        for (int matrixIdx = 0; matrixIdx < numMatrices; matrixIdx++) {
            PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()];
            for (int i = 0; i < partial_plaintexts.length; i++) {
                partial_plaintexts[i] = partial_decryptions.get(i)[matrixIdx];
            }
            result[matrixIdx] = _seal_server.average(encrypted_sums[matrixIdx], partial_plaintexts);
        }
        ListObject new_model = new ListObject(getResult());
        for (int i = 0; i < new_model.getLength(); i++) {
            new_model.set(i, (Data) result[i]);
        }
        if (tDecrypt != null) {
            ParamServStatistics.accHEDecryptionTime((long) tDecrypt.stop());
        }
        updateAndBroadcastModel(new_model, null);
        return null;
    }

    private void startCommTimer() {
        commTimer = new Timing(true);
    }

    private long stopCommTimer() {
        return (long) commTimer.stop();
    }

    /**
     * Runs the two-phase encrypted update for one worker. This blocks until all workers have
     * pushed and until all of them have supplied their partial decryptions.
     *
     * @param workerID        id of the pushing worker; its control thread must be registered
     * @param encrypted_model the worker's encrypted model (one ciphertext per matrix)
     */
    @Override
    public void push(int workerID, ListObject encrypted_model) {
        // Phase 1: sum all encrypted models; the last arriving worker also starts the comm timer.
        CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, models -> {
            CiphertextMatrix[] sum = homomorphicAggregation(models);
            startCommTimer();
            return sum;
        });
        PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum);
        // Phase 2: combine all partial decryptions into the new model and broadcast it.
        collectAndDo(workerID, partial_decryption, partials -> {
            ParamServStatistics.accFedNetworkTime(stopCommTimer());
            return homomorphicAverage(homomorphic_sum, partials);
        });
    }
}

