/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.estim;

import java.util.BitSet;
import java.util.stream.IntStream;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class EstimatorBitsetMM
extends SparsityEstimator {
    @Override
    public DataCharacteristics estim(MMNode root) {
        BitsetMatrix m1Map = this.getCachedSynopsis(root.getLeft());
        BitsetMatrix m2Map = this.getCachedSynopsis(root.getRight());
        BitsetMatrix outMap = EstimatorBitsetMM.estimInternal(m1Map, m2Map, root.getOp());
        root.setSynopsis(outMap);
        return root.setDataCharacteristics(new MatrixCharacteristics((long)outMap.getNumRows(), (long)outMap.getNumColumns(), outMap.getNonZeros()));
    }

    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        return this.estim(m1, m2, SparsityEstimator.OpCode.MM);
    }

    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (this.isExactMetadataOp(op)) {
            return this.estimExactMetaData(m1.getDataCharacteristics(), m2.getDataCharacteristics(), op).getSparsity();
        }
        BitsetMatrix m1Map = EstimatorBitsetMM.createBitset(m1);
        BitsetMatrix m2Map = m1 == m2 ? m1Map : EstimatorBitsetMM.createBitset(m2);
        BitsetMatrix outMap = EstimatorBitsetMM.estimInternal(m1Map, m2Map, op);
        return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
    }

    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        if (this.isExactMetadataOp(op)) {
            return this.estimExactMetaData(m.getDataCharacteristics(), null, op).getSparsity();
        }
        BitsetMatrix m1Map = EstimatorBitsetMM.createBitset(m);
        BitsetMatrix outMap = EstimatorBitsetMM.estimInternal(m1Map, null, op);
        return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
    }

    private BitsetMatrix getCachedSynopsis(MMNode node) {
        if (node == null) {
            return null;
        }
        if (node.isLeaf() && node.getSynopsis() == null) {
            node.setSynopsis(EstimatorBitsetMM.createBitset(node.getData()));
        } else if (!node.isLeaf()) {
            this.estim(node);
        }
        return (BitsetMatrix)node.getSynopsis();
    }

    private static BitsetMatrix estimInternal(BitsetMatrix m1Map, BitsetMatrix m2Map, SparsityEstimator.OpCode op) {
        switch (op) {
            case MM: {
                return m1Map.matMult(m2Map);
            }
            case MULT: {
                return m1Map.and(m2Map);
            }
            case PLUS: {
                return m1Map.or(m2Map);
            }
            case RBIND: {
                return m1Map.rbind(m2Map);
            }
            case CBIND: {
                return m1Map.cbind(m2Map);
            }
            case NEQZERO: {
                return m1Map;
            }
            case EQZERO: {
                return m1Map.flip();
            }
            case TRANS: {
                return m1Map.transpose();
            }
        }
        throw new NotImplementedException();
    }

    public static BitsetMatrix createBitset(int m, int n) {
        return (long)m * (long)n < Integer.MAX_VALUE ? new BitsetMatrix1(m, n) : new BitsetMatrix2(m, n);
    }

    public static BitsetMatrix createBitset(MatrixBlock in) {
        return in.getLength() < Integer.MAX_VALUE ? new BitsetMatrix1(in) : new BitsetMatrix2(in);
    }

    public static class BitsetMatrix2
    extends BitsetMatrix {
        private BitSet[] _data;

        public BitsetMatrix2(int rlen, int clen) {
            super(rlen, clen);
            this._data = new BitSet[this._rlen];
        }

        public BitsetMatrix2(MatrixBlock in) {
            this(in.getNumRows(), in.getNumColumns());
            this.init(in);
        }

        @Override
        protected BitsetMatrix createBitSetMatrix(int rlen, int clen) {
            return new BitsetMatrix2(rlen, clen);
        }

        @Override
        protected void buildIntern(MatrixBlock in, int rl, int ru) {
            int clen = in.getNumColumns();
            if (in.isInSparseFormat()) {
                SparseBlock sblock = in.getSparseBlock();
                for (int i = rl; i < ru; ++i) {
                    if (sblock.isEmpty(i)) continue;
                    BitSet lbs = this._data[i] = new BitSet(clen);
                    int alen = sblock.size(i);
                    int apos = sblock.pos(i);
                    int[] aix = sblock.indexes(i);
                    for (int k = apos; k < apos + alen; ++k) {
                        lbs.set(aix[k]);
                    }
                }
            } else {
                DenseBlock dblock = in.getDenseBlock();
                for (int i = rl; i < ru; ++i) {
                    BitSet lbs = this._data[i] = new BitSet(clen);
                    double[] avals = dblock.values(i);
                    int aix = dblock.pos(i);
                    for (int j = 0; j < in.getNumColumns(); ++j) {
                        if (avals[aix + j] == 0.0) continue;
                        lbs.set(j);
                    }
                }
            }
        }

        @Override
        protected long matMultIntern(BitsetMatrix bsb2, BitsetMatrix bsc2, int rl, int ru) {
            BitsetMatrix2 bsb = (BitsetMatrix2)bsb2;
            BitsetMatrix2 bsc = (BitsetMatrix2)bsc2;
            int cd = this._clen;
            int n = bsb._clen;
            long lnnz = 0L;
            for (int i = rl; i < ru; ++i) {
                BitSet a = this._data[i];
                if (a == null) continue;
                BitSet c = bsc._data[i] = new BitSet(n);
                for (int k = 0; k < cd; ++k) {
                    BitSet b = bsb._data[k];
                    if (!a.get(k) || b == null) continue;
                    c.or(b);
                }
                lnnz += (long)c.cardinality();
            }
            return lnnz;
        }

        @Override
        public BitsetMatrix and(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix2)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix2 b = (BitsetMatrix2)bsb;
            BitsetMatrix2 ret = new BitsetMatrix2(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = (BitSet)this._data[i].clone();
                ret._data[i].and(b._data[i]);
                ret._nonZeros += (long)ret._data[i].cardinality();
            }
            return ret;
        }

        @Override
        public BitsetMatrix or(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix2)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix2 b = (BitsetMatrix2)bsb;
            BitsetMatrix2 ret = new BitsetMatrix2(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = (BitSet)this._data[i].clone();
                ret._data[i].or(b._data[i]);
                ret._nonZeros += (long)ret._data[i].cardinality();
            }
            return ret;
        }

        @Override
        public BitsetMatrix rbind(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix2)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix2 b = (BitsetMatrix2)bsb;
            BitsetMatrix2 ret = new BitsetMatrix2(this.getNumRows() + bsb.getNumRows(), this.getNumColumns());
            System.arraycopy(this._data, 0, ret._data, 0, this._rlen);
            System.arraycopy(b._data, 0, ret._data, this._rlen, b._rlen);
            return ret;
        }

        @Override
        protected BitsetMatrix cbind(BitsetMatrix bsb) {
            int i;
            if (!(bsb instanceof BitsetMatrix2)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix2 b = (BitsetMatrix2)bsb;
            BitsetMatrix2 ret = new BitsetMatrix2(this.getNumRows(), this.getNumColumns() + bsb.getNumColumns());
            for (i = 0; i < this.getNumRows(); ++i) {
                ret._data[i] = (BitSet)this._data[i].clone();
            }
            for (i = 0; i < this.getNumRows(); ++i) {
                for (int j = 0; j < b.getNumColumns(); ++j) {
                    if (!b.get(i, j)) continue;
                    ret.set(i, this.getNumColumns() + j);
                }
                ret._nonZeros += (long)ret._data[i].cardinality();
            }
            return ret;
        }

        @Override
        public BitsetMatrix flip() {
            BitsetMatrix2 ret = new BitsetMatrix2(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = (BitSet)this._data[i].clone();
                ret._data[i].flip(0, this._data[i].size());
                ret._nonZeros += (long)ret._data[i].cardinality();
            }
            return ret;
        }

        @Override
        public boolean get(int r, int c) {
            return this._data[r].get(c);
        }

        @Override
        public void set(int r, int c) {
            this._data[r].set(c);
        }
    }

    public static class BitsetMatrix1
    extends BitsetMatrix {
        private final int _rowLen;
        private final long[] _data;

        public BitsetMatrix1(int rlen, int clen) {
            super(rlen, clen);
            this._rowLen = (int)Math.ceil((double)clen / 64.0);
            this._data = new long[rlen * this._rowLen];
        }

        public BitsetMatrix1(MatrixBlock in) {
            this(in.getNumRows(), in.getNumColumns());
            this.init(in);
        }

        @Override
        protected BitsetMatrix createBitSetMatrix(int rlen, int clen) {
            return new BitsetMatrix1(rlen, clen);
        }

        @Override
        protected void buildIntern(MatrixBlock in, int rl, int ru) {
            if (in.isInSparseFormat()) {
                SparseBlock sblock = in.getSparseBlock();
                for (int i = rl; i < ru; ++i) {
                    if (sblock.isEmpty(i)) continue;
                    int alen = sblock.size(i);
                    int apos = sblock.pos(i);
                    int[] aix = sblock.indexes(i);
                    for (int k = apos; k < apos + alen; ++k) {
                        this.set(i, aix[k]);
                    }
                }
            } else {
                DenseBlock dblock = in.getDenseBlock();
                for (int i = rl; i < ru; ++i) {
                    double[] avals = dblock.values(i);
                    int aix = dblock.pos(i);
                    for (int j = 0; j < in.getNumColumns(); ++j) {
                        if (avals[aix + j] == 0.0) continue;
                        this.set(i, j);
                    }
                }
            }
        }

        @Override
        protected long matMultIntern(BitsetMatrix bsb2, BitsetMatrix bsc2, int rl, int ru) {
            BitsetMatrix1 bsb = (BitsetMatrix1)bsb2;
            BitsetMatrix1 bsc = (BitsetMatrix1)bsc2;
            long[] b = bsb._data;
            long[] c = bsc._data;
            int cd = this._clen;
            int n = bsb._clen;
            int n64 = bsb._rowLen;
            int blocksizeI = 32;
            int blocksizeK = 24;
            int blocksizeJ = 65536;
            long lnnz = 0L;
            for (int bi = rl; bi < ru; bi += 32) {
                int bimin = Math.min(ru, bi + 32);
                for (int bk = 0; bk < cd; bk += 24) {
                    int bkmin = Math.min(cd, bk + 24);
                    for (int bj = 0; bj < n; bj += 65536) {
                        int bjlen64 = (int)Math.ceil((double)(Math.min(n, bj + 65536) - bj) / 64.0);
                        int bj64 = bj / 64;
                        int i = bi;
                        int off = i * this._rowLen;
                        while (i < bimin) {
                            for (int k = bk; k < bkmin; ++k) {
                                if (!this.getCol(off, k)) continue;
                                BitsetMatrix1.or(b, c, k * n64 + bj64, i * n64 + bj64, bjlen64);
                            }
                            ++i;
                            off += this._rowLen;
                        }
                    }
                }
                lnnz += (long)BitsetMatrix1.card(c, bi * n64, (bimin - bi) * n64);
            }
            return lnnz;
        }

        @Override
        public BitsetMatrix and(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix1)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix1 b = (BitsetMatrix1)bsb;
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = this._data[i] & b._data[i];
            }
            ret._nonZeros = BitsetMatrix1.card(ret._data, 0, ret._data.length);
            return ret;
        }

        @Override
        public BitsetMatrix or(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix1)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix1 b = (BitsetMatrix1)bsb;
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = this._data[i] | b._data[i];
            }
            ret._nonZeros = BitsetMatrix1.card(ret._data, 0, ret._data.length);
            return ret;
        }

        @Override
        public BitsetMatrix rbind(BitsetMatrix bsb) {
            if (!(bsb instanceof BitsetMatrix1)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix1 b = (BitsetMatrix1)bsb;
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows() + bsb.getNumRows(), this.getNumColumns());
            System.arraycopy(this._data, 0, ret._data, 0, this._rlen * this._rowLen);
            System.arraycopy(b._data, 0, ret._data, this._rlen * this._rowLen, b._rlen * this._rowLen);
            ret._nonZeros = BitsetMatrix1.card(ret._data, 0, ret._data.length);
            return ret;
        }

        @Override
        public BitsetMatrix cbind(BitsetMatrix bsb) {
            int i;
            if (!(bsb instanceof BitsetMatrix1)) {
                throw new HopsException("Incompatible bitset types: " + this.getClass().getSimpleName() + " and " + bsb.getClass().getSimpleName());
            }
            BitsetMatrix1 b = (BitsetMatrix1)bsb;
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows(), this.getNumColumns() + bsb.getNumColumns());
            for (i = 0; i < this.getNumRows(); ++i) {
                System.arraycopy(this._data, i * this._rowLen, ret._data, i * ret._rowLen, this._rowLen);
            }
            for (i = 0; i < this.getNumRows(); ++i) {
                for (int j = 0; j < b.getNumColumns(); ++j) {
                    if (!b.get(i, j)) continue;
                    ret.set(i, this.getNumColumns() + j);
                }
            }
            ret._nonZeros = BitsetMatrix1.card(ret._data, 0, ret._data.length);
            return ret;
        }

        @Override
        public BitsetMatrix flip() {
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this._data.length; ++i) {
                ret._data[i] = this._data[i] ^ 0xFFFFFFFFFFFFFFFFL;
            }
            ret._nonZeros = (long)this.getNumRows() * (long)this.getNumColumns() - this.getNonZeros();
            return ret;
        }

        @Override
        public void set(int r, int c) {
            int off = r * this._rowLen;
            int wordIndex = BitsetMatrix1.wordIndex(c);
            int n = off + wordIndex;
            this._data[n] = this._data[n] | 1L << c;
        }

        @Override
        public boolean get(int r, int c) {
            int off = r * this._rowLen;
            int wordIndex = BitsetMatrix1.wordIndex(c);
            return (this._data[off + wordIndex] & 1L << c) != 0L;
        }

        private boolean getCol(int off, int c) {
            int wordIndex = BitsetMatrix1.wordIndex(c);
            return (this._data[off + wordIndex] & 1L << c) != 0L;
        }

        private static int wordIndex(int bitIndex) {
            return bitIndex >> 6;
        }

        private static int card(long[] c, int ci, int len) {
            int sum = 0;
            for (int i = ci; i < ci + len; ++i) {
                sum += Long.bitCount(c[i]);
            }
            return sum;
        }

        private static void or(long[] b, long[] c, int bi, int ci, int len) {
            int bn = len % 8;
            int i = 0;
            while (i < bn) {
                int n = ci++;
                c[n] = c[n] | b[bi];
                ++i;
                ++bi;
            }
            i = bn;
            while (i < len) {
                int n = ci + 0;
                c[n] = c[n] | b[bi + 0];
                int n2 = ci + 1;
                c[n2] = c[n2] | b[bi + 1];
                int n3 = ci + 2;
                c[n3] = c[n3] | b[bi + 2];
                int n4 = ci + 3;
                c[n4] = c[n4] | b[bi + 3];
                int n5 = ci + 4;
                c[n5] = c[n5] | b[bi + 4];
                int n6 = ci + 5;
                c[n6] = c[n6] | b[bi + 5];
                int n7 = ci + 6;
                c[n7] = c[n7] | b[bi + 6];
                int n8 = ci + 7;
                c[n8] = c[n8] | b[bi + 7];
                i += 8;
                bi += 8;
                ci += 8;
            }
        }
    }

    public static abstract class BitsetMatrix {
        protected final int _rlen;
        protected final int _clen;
        protected long _nonZeros;

        public BitsetMatrix(int rlen, int clen) {
            this._rlen = rlen;
            this._clen = clen;
            this._nonZeros = 0L;
        }

        public int getNumRows() {
            return this._rlen;
        }

        public int getNumColumns() {
            return this._clen;
        }

        public long getNonZeros() {
            return this._nonZeros;
        }

        public abstract boolean get(int var1, int var2);

        public abstract void set(int var1, int var2);

        protected void init(MatrixBlock in) {
            if (in.isEmptyBlock(false)) {
                return;
            }
            if (SparsityEstimator.MULTI_THREADED_BUILD && in.getNonZeros() > 10240L) {
                int k = 4 * InfrastructureAnalyzer.getLocalParallelism();
                int blklen = (int)Math.ceil((double)this._rlen / (double)k);
                IntStream.range(0, k).parallel().forEach(i -> this.buildIntern(in, i * blklen, Math.min((i + 1) * blklen, this._rlen)));
            } else {
                this.buildIntern(in, 0, in.getNumRows());
            }
            this._nonZeros = in.getNonZeros();
        }

        public BitsetMatrix matMult(BitsetMatrix m2) {
            BitsetMatrix out = this.createBitSetMatrix(this._rlen, m2._clen);
            if (this.getNonZeros() == 0L || m2.getNonZeros() == 0L) {
                return out;
            }
            long size = (long)this._rlen * (long)this._clen + (long)m2._rlen * (long)m2._clen;
            if (SparsityEstimator.MULTI_THREADED_ESTIM && size > 10240L) {
                int k = 4 * InfrastructureAnalyzer.getLocalParallelism();
                int blklen = (int)Math.ceil((double)this._rlen / (double)k);
                out._nonZeros = IntStream.range(0, k).parallel().mapToLong(i -> this.matMultIntern(m2, out, i * blklen, Math.min((i + 1) * blklen, this._rlen))).sum();
            } else {
                out._nonZeros = this.matMultIntern(m2, out, 0, this._rlen);
            }
            return out;
        }

        protected abstract BitsetMatrix createBitSetMatrix(int var1, int var2);

        protected abstract void buildIntern(MatrixBlock var1, int var2, int var3);

        protected abstract long matMultIntern(BitsetMatrix var1, BitsetMatrix var2, int var3, int var4);

        protected abstract BitsetMatrix and(BitsetMatrix var1);

        protected abstract BitsetMatrix or(BitsetMatrix var1);

        protected abstract BitsetMatrix rbind(BitsetMatrix var1);

        protected abstract BitsetMatrix cbind(BitsetMatrix var1);

        protected abstract BitsetMatrix flip();

        public BitsetMatrix transpose() {
            BitsetMatrix1 ret = new BitsetMatrix1(this.getNumRows(), this.getNumColumns());
            for (int i = 0; i < this.getNumColumns(); ++i) {
                for (int k = 0; k < this.getNumRows(); ++k) {
                    if (!this.get(i, k)) continue;
                    ret.set(k, i);
                }
            }
            return ret;
        }
    }
}

