/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
import org.apache.sysml.utils.GPUStatistics;

/**
 * Fetches single rows of an input matrix into a reusable dense device buffer.
 *
 * <p>The constructor resolves the input's device pointer (sparse or dense) and
 * allocates one output buffer sized for a single row. Every call to
 * {@link #getNthRow(int)} overwrites that same buffer. The buffer is released
 * by {@link #close()}, so instances are intended to be used in a
 * try-with-resources statement.
 */
public class LibMatrixCuDNNInputRowFetcher
extends LibMatrixCUDA
implements AutoCloseable {
    GPUContext gCtx;
    String instName;
    int numColumns;
    boolean isInputInSparseFormat;
    /** Holds a {@link CSRPointer} when the input is sparse, otherwise a {@link Pointer}. */
    Object inPointer;
    /** Single-row output buffer; set to {@code null} once {@link #close()} has released it. */
    Pointer outPointer;

    /**
     * Initializes the row fetcher and allocates the single-row output buffer.
     *
     * @param gCtx     GPU context used for allocation, slicing and deallocation
     * @param instName instruction name forwarded to the allocation, slicing and statistics calls
     * @param image    input matrix object whose rows are fetched
     */
    public LibMatrixCuDNNInputRowFetcher(GPUContext gCtx, String instName, MatrixObject image) {
        this.gCtx = gCtx;
        this.instName = instName;
        this.numColumns = LibMatrixCUDA.toInt(image.getNumColumns());
        this.isInputInSparseFormat = LibMatrixCUDA.isInSparseFormat(gCtx, image);
        if (this.isInputInSparseFormat) {
            this.inPointer = LibMatrixCUDA.getSparsePointer(gCtx, image, instName);
        } else {
            this.inPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
        }
        this.outPointer = gCtx.allocate(instName, rowSizeInBytes());
    }

    /**
     * Computes the size of one output row in bytes.
     *
     * <p>The multiplication is carried out in {@code long} arithmetic: an
     * {@code int * int} product would silently wrap for very wide rows before
     * being widened to the {@code long} expected by the allocation and memset calls.
     *
     * @return {@code numColumns * sizeOfDataType} as a {@code long}
     */
    private long rowSizeInBytes() {
        return (long) this.numColumns * sizeOfDataType;
    }

    /**
     * Copies the n-th row of the input into the shared output buffer.
     *
     * @param n row index, passed as both the lower and upper row bound of the slice
     * @return the output buffer holding the row; the same pointer is returned (and
     *         overwritten) by every subsequent call
     */
    public Pointer getNthRow(int n) {
        int lastColumn = LibMatrixCUDA.toInt(this.numColumns - 1);
        if (this.isInputInSparseFormat) {
            // The buffer is cleared before the sparse slice; presumably the sparse
            // kernel only writes non-zero entries -- TODO confirm against sliceSparseDense.
            JCuda.cudaDeviceSynchronize();
            long t0 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            JCuda.cudaMemset(this.outPointer, 0, rowSizeInBytes());
            JCuda.cudaDeviceSynchronize();
            if (DMLScript.FINEGRAINED_STATISTICS) {
                // NOTE(review): "az" looks like an inlined misc-timer key constant -- confirm.
                GPUStatistics.maintainCPMiscTimes(this.instName, "az", System.nanoTime() - t0);
            }
            LibMatrixCUDA.sliceSparseDense(this.gCtx, this.instName, (CSRPointer)this.inPointer, this.outPointer, n, n, 0, lastColumn, this.numColumns);
        } else {
            LibMatrixCUDA.sliceDenseDense(this.gCtx, this.instName, (Pointer)this.inPointer, this.outPointer, n, n, 0, lastColumn, this.numColumns);
        }
        return this.outPointer;
    }

    /**
     * Releases the single-row output buffer.
     *
     * <p>Calling this method more than once is a no-op after the first successful
     * release, so the same device pointer is never handed to the free helper twice.
     *
     * @throws RuntimeException wrapping the {@link DMLRuntimeException} raised by the free helper
     */
    @Override
    public void close() {
        if (this.outPointer == null) {
            return;
        }
        try {
            this.gCtx.cudaFreeHelper(null, this.outPointer, true);
            this.outPointer = null;
        }
        catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }
}

