/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.data.DnnParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNConv2d;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNRelu;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.CommonThreadPool;
import org.apache.sysml.runtime.util.DnnUtils;

public class LibMatrixDNN {
    /** Logger shared by this class. */
    protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());

    // Invocation counters reported by appendStatistics() and cleared by resetStatistics().
    // They are private and never reassigned, so they are declared final.
    // The conv2d* counters are incremented by the checkInputs* methods and the maxPoolBwd*
    // counters by poolingBackward(), in both cases only when DMLScript.FINEGRAINED_STATISTICS
    // is set.
    private static final AtomicLong conv2dSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dDenseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdFilterSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdFilterDenseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdDataSparseCount = new AtomicLong(0L);
    private static final AtomicLong conv2dBwdDataDenseCount = new AtomicLong(0L);
    // NOTE(review): nothing in this class increments the im2col counters, so they are always
    // reported as 0.
    private static final AtomicLong im2colSparseCount = new AtomicLong(0L);
    private static final AtomicLong im2colDenseCount = new AtomicLong(0L);
    private static final AtomicLong maxPoolBwdSparseCount = new AtomicLong(0L);
    private static final AtomicLong maxPoolBwdDenseCount = new AtomicLong(0L);

    // Package-visible timers. This class only reads and resets them; presumably the worker
    // classes of this package accumulate into them - TODO confirm. appendStatistics() scales
    // them by 1e-9 and labels the result "sec", so the unit is presumably nanoseconds.
    // Left non-final because code outside this file may reference them.
    static AtomicLong loopedConvMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvIm2ColTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0L);
    static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0L);

    /**
     * Appends the DNN statistics (dense counts, sparse counts, and accumulated times) as three
     * text lines to the given builder. Nothing is appended unless
     * {@code DMLScript.FINEGRAINED_STATISTICS} is set.
     *
     * @param sb builder that receives the report
     */
    public static void appendStatistics(StringBuilder sb) {
        if (!DMLScript.FINEGRAINED_STATISTICS) {
            return;
        }

        sb.append("LibMatrixDNN dense count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dDenseCount.get()).append("/")
            .append(conv2dBwdFilterDenseCount.get()).append("/")
            .append(conv2dBwdDataDenseCount.get()).append("/")
            .append(im2colDenseCount.get()).append("/")
            .append(maxPoolBwdDenseCount.get())
            .append(".\n");

        sb.append("LibMatrixDNN sparse count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dSparseCount.get()).append("/")
            .append(conv2dBwdFilterSparseCount.get()).append("/")
            .append(conv2dBwdDataSparseCount.get()).append("/")
            .append(im2colSparseCount.get()).append("/")
            .append(maxPoolBwdSparseCount.get())
            .append(".\n");

        // Timers in the order announced by the label below.
        AtomicLong[] timers = {
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterIm2ColTime, loopedConvBwdFilterMatMultTime,
            loopedConvBwdDataCol2ImTime, loopedConvBwdDataMatMultTime
        };
        sb.append("LibMatrixDNN conv(im2col/matmult), bwdF (im2col/matmult), bwdD (col2im/matmult) time:\t");
        for (int i = 0; i < timers.length; i++) {
            if (i > 0) {
                sb.append("/");
            }
            sb.append(String.format("%.3f", (double) timers[i].get() * 1.0E-9));
        }
        sb.append(" sec.\n");
    }

    /**
     * Sets every counter and timer reported by {@link #appendStatistics(StringBuilder)} back to zero.
     */
    public static void resetStatistics() {
        AtomicLong[] all = {
            conv2dDenseCount, conv2dBwdFilterDenseCount, conv2dBwdDataDenseCount,
            im2colDenseCount, maxPoolBwdDenseCount,
            conv2dSparseCount, conv2dBwdFilterSparseCount, conv2dBwdDataSparseCount,
            im2colSparseCount, maxPoolBwdSparseCount,
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterMatMultTime, loopedConvBwdFilterIm2ColTime,
            loopedConvBwdDataMatMultTime, loopedConvBwdDataCol2ImTime
        };
        for (AtomicLong stat : all) {
            stat.set(0L);
        }
    }

    /**
     * Runs the conv2d workers for the given input and filter and stores the result in
     * {@code outputBlock}.
     *
     * <p>Inputs are validated and registered on {@code params} by {@code checkInputsConv2d}.
     * A sparse {@code params.bias} is converted to dense before the workers run.
     *
     * @param input       input matrix
     * @param filter      filter matrix
     * @param outputBlock output matrix; its non-zero count is set to the sum returned by the workers
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2d(input, filter, outputBlock, params);

        boolean sparseBias = params.bias != null && params.bias.isInSparseFormat();
        if (sparseBias) {
            params.bias.sparseToDense();
        }

        ArrayList<Callable<Long>> workers = LibMatrixDNNConv2d.getConv2dWorkers(params);
        outputBlock.setNonZeros(execute(workers, params));
        outputBlock.examSparsity();
    }

    /**
     * Runs the conv2d backward-data workers and stores the result in {@code outputBlock}.
     *
     * @param filter      filter matrix
     * @param dout        matrix of errors from the next layer
     * @param outputBlock output matrix; its non-zero count is set to the sum returned by the workers
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2dBackwardData(filter, dout, outputBlock, params);

        ArrayList<Callable<Long>> workers = LibMatrixDNNConv2d.getConv2dBackwardDataWorkers(params);
        outputBlock.setNonZeros(execute(workers, params));
        outputBlock.examSparsity();
    }

    /**
     * Runs the conv2d backward-filter workers and stores the result in {@code outputBlock}.
     *
     * <p>Unlike the other conv2d entry points, the value returned by the workers is ignored and
     * the non-zero count is recomputed from the output.
     *
     * @param input       input matrix
     * @param dout        matrix of errors from the next layer
     * @param outputBlock output matrix
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);

        ArrayList<Callable<Long>> workers = LibMatrixDNNConv2d.getConv2dBackwardFilterWorkers(params);
        execute(workers, params);

        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Runs the pooling workers of the requested type over {@code input}.
     *
     * @param input    input matrix; must have {@code params.N} rows and
     *                 {@code params.C * params.H * params.W} columns
     * @param output   output matrix; its non-zero count is set to the sum returned by the workers
     * @param params   pooling parameters (mutated: input1 and output are assigned, and the
     *                 start/end index arrays are filled when needed)
     * @param poolType pooling variant to run
     * @throws DMLRuntimeException if the input dimensions do not match {@code params}
     */
    public static void pooling(MatrixBlock input, MatrixBlock output, DnnParameters params, PoolingType poolType) {
        params.input1 = input;
        params.output = output;

        boolean dimsMatch = input.getNumColumns() == params.C * params.H * params.W
            && input.getNumRows() == params.N;
        if (!dimsMatch) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C * params.H * params.W);
        }

        // The index arrays are skipped only for a dense input with stride 1 and no padding.
        boolean needsIndexArrays = !params.isStride1Pad0() || input.sparse;
        if (needsIndexArrays) {
            fillIndexesArray(params);
        }

        ArrayList<Callable<Long>> workers = LibMatrixDNNPooling.getPoolingWorkers(params, poolType);
        output.setNonZeros(execute(workers, params));
        output.examSparsity();
    }

    /**
     * Runs the pooling backward workers of the requested type.
     *
     * <p>Changes relative to the previous version:
     * <ul>
     *   <li>the sparse/dense statistics flag is computed in a single expression (the decompiled
     *       form left the variable unassigned on the MAX branch);</li>
     *   <li>the two error messages now report the dimensions that were actually checked.</li>
     * </ul>
     *
     * @param input               input of the forward pass; validated against
     *                            {@code N x C*H*W} only for {@link PoolingType#MAX}
     * @param dout                errors from the next layer; must be {@code N x C*P*Q}
     * @param outputBlock         output matrix; must not be in sparse format
     * @param params              pooling parameters (mutated: input1, input2, output are assigned,
     *                            and the start/end index arrays are filled when needed)
     * @param performReluBackward forwarded unchanged to the worker factory
     * @param poolType            pooling variant
     * @throws DMLRuntimeException on a dimension mismatch or a sparse output block
     */
    public static void poolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params, boolean performReluBackward, PoolingType poolType) {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;

        if (poolType == PoolingType.MAX && (input.getNumColumns() != params.C * params.H * params.W || input.getNumRows() != params.N)) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C * params.H * params.W);
        }
        if (dout.getNumColumns() != params.C * params.P * params.Q || dout.getNumRows() != params.N) {
            throw new DMLRuntimeException("Incorrect dout dimensions in pooling_backward:" + dout.getNumRows() + " " + dout.getNumColumns() + " " + params.N + " " + params.C * params.P * params.Q);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            // Max pooling reads both inputs, every other pooling type only dout.
            boolean isSparse = poolType == PoolingType.MAX
                ? input.isInSparseFormat() || dout.isInSparseFormat()
                : dout.isInSparseFormat();
            if (isSparse) {
                maxPoolBwdSparseCount.addAndGet(1L);
            } else {
                maxPoolBwdDenseCount.addAndGet(1L);
            }
        }

        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse pooling_backward is not supported");
        }

        // The index arrays are skipped only for max pooling with a sparse input and a dense dout.
        if (poolType == PoolingType.AVG || !params.input1.isInSparseFormat() || params.input2.isInSparseFormat()) {
            fillIndexesArray(params);
        }

        long nnz = execute(LibMatrixDNNPooling.getPoolingBackwardWorkers(params, performReluBackward, poolType), params);
        outputBlock.setNonZeros(nnz);
        outputBlock.examSparsity();
    }

    /**
     * Runs the relu backward workers on {@code input} and {@code dout}.
     *
     * @param input       input of the forward pass
     * @param dout        errors from the next layer; must have the same dimensions as {@code input}
     * @param outputBlock output matrix; its non-zero count is set to the sum returned by the workers
     * @param numThreads  requested degree of parallelism, passed on through the parameter object
     * @throws DMLRuntimeException if {@code input} and {@code dout} differ in shape
     */
    public static void reluBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, int numThreads) {
        int numRows = input.getNumRows();
        // Only the row count and thread count are meaningful here; the rest is passed as -1.
        DnnParameters params = new DnnParameters(numRows, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;

        boolean sameShape = input.getNumRows() == dout.getNumRows()
            && input.getNumColumns() == dout.getNumColumns();
        if (!sameShape) {
            throw new DMLRuntimeException("Incorrect dimensions for relu_backward:" + input.getNumRows() + " != " + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
        }

        ArrayList<Callable<Long>> workers = LibMatrixDNNRelu.getReluBackwardWorkers(params);
        outputBlock.setNonZeros(execute(workers, params));
        outputBlock.examSparsity();
    }

    /**
     * Adds a per-channel bias to {@code input} and writes the result to {@code outputBlock}.
     *
     * <p>With {@code K = bias.getNumRows()} and {@code PQ = input.getNumColumns() / K}, every
     * row of the input is treated as {@code K} consecutive groups of {@code PQ} values, and
     * {@code bias[k]} is added to each value of group {@code k}.
     *
     * <p>Changes relative to the previous version:
     * <ul>
     *   <li>the inputs are validated before dividing by {@code K}, so a bias with zero rows
     *       raises the descriptive {@link DMLRuntimeException} instead of an
     *       {@code ArithmeticException};</li>
     *   <li>the dense output array is fetched after {@code outputBlock.copy(...)}, so the
     *       addition always targets the array the block holds after the copy.</li>
     * </ul>
     *
     * @param input       input matrix
     * @param bias        bias column vector ({@code K x 1}); converted to dense in place if sparse
     * @param outputBlock output matrix; assumed to have an allocated dense block when
     *                    {@code input} is empty - TODO confirm against callers
     * @param numThreads  unused by this implementation
     * @throws DMLRuntimeException if bias is not a non-empty column vector or the number of
     *                             input columns is not a multiple of {@code K}
     */
    public static void biasAdd(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        if (K == 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + N + " X " + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;

        if (input.isEmptyBlock()) {
            // Empty input: the output is just the bias broadcast over every row.
            double[] outputArray = outputBlock.getDenseBlockValues();
            for (int n = 0; n < N; ++n) {
                DnnUtils.fillBias(bias, outputArray, n, n + 1, N, K, PQ);
            }
        } else {
            // Copy the input (dense or sparse) into a dense output, then add the bias in place.
            outputBlock.copy(input, false);
            if (bias.isInSparseFormat()) {
                bias.sparseToDense();
            }
            double[] outputArray = outputBlock.getDenseBlockValues();
            double[] biasArr = bias.getDenseBlockValues();
            addBias(outputArray, biasArr, 1.0, N, K, PQ);
        }
        outputBlock.recomputeNonZeros();
        outputBlock.examSparsity();
    }

    /**
     * Sums the values of {@code input} per channel into the dense values of {@code outputBlock}.
     *
     * <p>Column {@code j} of a row belongs to channel {@code j / HW}. The two code paths differ:
     * the sparse path adds each non-zero onto whatever {@code output[c]} already holds using
     * plain addition, whereas the dense path accumulates with Kahan summation and overwrites
     * {@code output[c]}. A dense input without an allocated value array leaves the output
     * untouched.
     *
     * @param input       input matrix
     * @param outputBlock receives one sum per channel in its dense value array
     * @param C           number of channels (only used by the dense path)
     * @param HW          number of values per channel in each row
     */
    public static void channelSums(MatrixBlock input, MatrixBlock outputBlock, int C, int HW) {
        double[] output = outputBlock.getDenseBlockValues();

        if (input.isInSparseFormat()) {
            SparseBlock rows = input.getSparseBlock();
            for (int r = 0; r < input.getNumRows(); r++) {
                if (rows.isEmpty(r)) {
                    continue;
                }
                int begin = rows.pos(r);
                int end = begin + rows.size(r);
                int[] colIdx = rows.indexes(r);
                double[] vals = rows.values(r);
                for (int j = begin; j < end; j++) {
                    output[colIdx[j] / HW] += vals[j];
                }
            }
        } else {
            double[] dense = input.getDenseBlockValues();
            if (dense != null) {
                KahanPlus kahan = KahanPlus.getKahanPlusFnObject();
                for (int c = 0; c < C; c++) {
                    KahanObject acc = new KahanObject(0.0, 0.0);
                    for (int r = 0; r < input.getNumRows(); r++) {
                        int start = r * C * HW + c * HW;
                        for (int off = 0; off < HW; off++) {
                            kahan.execute2(acc, dense[start + off]);
                        }
                    }
                    output[c] = acc._sum;
                }
            }
        }

        outputBlock.recomputeNonZeros();
    }

    /**
     * Computes the batch-normalization backward pass per channel.
     *
     * <p>With {@code K = scale.getNumRows()} and {@code PQ = image.getNumColumns() / K}, each
     * row of {@code image} and {@code dout} is treated as {@code K} groups of {@code PQ} values.
     * For every channel {@code k} the method writes:
     * <ul>
     *   <li>{@code dBias[k]} - the sum of {@code dout} over the channel;</li>
     *   <li>{@code dScale[k]} - the sum of {@code (image - mean[k]) * invVar[k] * dout};</li>
     *   <li>{@code dX} - the sum of the norm, mean, and variance branches for every element.</li>
     * </ul>
     * All matrix arguments that are in sparse format are converted to dense in place. A missing
     * dense array for mean, inverse variance, or scale is treated as all zeros.
     *
     * <p>Changes relative to the previous version:
     * <ul>
     *   <li>the leading {@code channelSums(image, dBias, K, PQ)} call was removed: its result was
     *       overwritten for every {@code k} by {@code dBiasArr[k] = sumDout}, and it read the
     *       dense array of {@code dBias} before {@code dBias} was converted to dense;</li>
     *   <li>an unused local was dropped and per-channel loop invariants were hoisted (the
     *       floating-point evaluation order is unchanged).</li>
     * </ul>
     *
     * @param image                 input of the forward pass
     * @param dout                  errors from the next layer
     * @param scale                 per-channel scale ({@code K x 1})
     * @param epsilon               unused by this implementation
     * @param resultSaveMean        per-channel mean saved by the forward pass
     * @param resultSaveInvVariance per-channel inverse standard deviation saved by the forward pass
     * @param dX                    output: gradient with respect to the input
     * @param dScale                output: gradient with respect to the scale
     * @param dBias                 output: gradient with respect to the bias
     * @throws DMLRuntimeException if image or dout is still sparse after the dense conversion
     */
    public static void batchNorm2DBackward(MatrixBlock image, MatrixBlock dout, MatrixBlock scale, double epsilon, MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance, MatrixBlock dX, MatrixBlock dScale, MatrixBlock dBias) {
        int N = image.getNumRows();
        int K = scale.getNumRows();
        int PQ = image.getNumColumns() / K;

        if (dBias.isInSparseFormat()) {
            dBias.sparseToDense();
        }
        if (dScale.isInSparseFormat()) {
            dScale.sparseToDense();
        }
        if (dX.isInSparseFormat()) {
            dX.sparseToDense();
        }
        if (resultSaveMean.isInSparseFormat()) {
            resultSaveMean.sparseToDense();
        }
        if (resultSaveInvVariance.isInSparseFormat()) {
            resultSaveInvVariance.sparseToDense();
        }
        if (scale.isInSparseFormat()) {
            scale.sparseToDense();
        }

        double[] dBiasArr = dBias.getDenseBlockValues();
        double[] dScaleArr = dScale.getDenseBlockValues();
        double[] dXArr = dX.getDenseBlockValues();
        double[] mean = resultSaveMean.getDenseBlockValues();
        double[] invVar = resultSaveInvVariance.getDenseBlockValues();
        double[] scaleArr = scale.getDenseBlockValues();
        // A null dense array stands for an all-zero vector.
        mean = mean == null ? new double[K] : mean;
        invVar = invVar == null ? new double[K] : invVar;
        scaleArr = scaleArr == null ? new double[K] : scaleArr;

        if (image.isInSparseFormat()) {
            image.sparseToDense();
        }
        if (dout.isInSparseFormat()) {
            dout.sparseToDense();
        }
        if (image.isInSparseFormat() || dout.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse format is not yet supported for batch norm backward");
        }

        double[] imageArr = image.getDenseBlockValues();
        double[] doutArr = dout.getDenseBlockValues();
        double constant1 = Math.pow(N * PQ, -1.0);
        int KPQ = K * PQ;
        for (int k = 0; k < K; ++k) {
            double scaleK = scaleArr[k];
            double meanK = mean[k];
            double invVarK = invVar[k];
            double invVarCubed = Math.pow(invVarK, 3.0);

            // First pass: channel-wide reductions.
            double dvar = 0.0;
            double dmean_norm_branch = 0.0;
            double dmean_var_branch = 0.0;
            double sumDout = 0.0;
            double sum = 0.0;
            for (int n = 0; n < N; ++n) {
                int index = n * KPQ + k * PQ;
                for (int pq = 0; pq < PQ; ++pq, ++index) {
                    double doutVal = doutArr != null ? doutArr[index] : 0.0;
                    double centered = (imageArr != null ? imageArr[index] : 0.0) - meanK;
                    double dnorm = doutVal * scaleK;
                    dvar -= 0.5 * centered * invVarCubed * dnorm;
                    dmean_norm_branch -= dnorm * invVarK;
                    sum += centered * invVarK * doutVal;
                    sumDout += doutVal;
                    dmean_var_branch -= 2.0 * constant1 * centered;
                }
            }
            dBiasArr[k] = sumDout;
            dScaleArr[k] = sum;
            dmean_var_branch *= dvar;
            double dmean = dmean_norm_branch + dmean_var_branch;
            double dX_mean_branch = constant1 * dmean;

            // Second pass: per-element input gradient.
            for (int n = 0; n < N; ++n) {
                int index = n * KPQ + k * PQ;
                for (int pq = 0; pq < PQ; ++pq, ++index) {
                    double doutVal = doutArr != null ? doutArr[index] : 0.0;
                    double centered = (imageArr != null ? imageArr[index] : 0.0) - meanK;
                    double dnorm = doutVal * scaleK;
                    double dX_norm_branch = dnorm * invVarK;
                    double dX_var_branch = 2.0 * constant1 * centered * dvar;
                    dXArr[index] = dX_norm_branch + dX_mean_branch + dX_var_branch;
                }
            }
        }

        dBias.recomputeNonZeros();
        dScale.recomputeNonZeros();
        dX.recomputeNonZeros();
    }

    /**
     * Computes the batch-normalization forward pass per channel.
     *
     * <p>With {@code K = bias.getNumRows()} and {@code PQ = image.getNumColumns() / K}, each row
     * of {@code image} is treated as {@code K} groups of {@code PQ} values.
     * <ul>
     *   <li>Phase {@code "train"} (case-insensitive): the per-channel mean and variance are
     *       computed from {@code image}. The mean and {@code 1/sqrt(var + epsilon)} are stored in
     *       {@code resultSaveMean} and {@code resultSaveInvVariance}, and the running statistics
     *       are updated as {@code mu * running + (1 - mu) * batchStatistic}. The two result
     *       arrays are used as accumulators and are assumed to be zero on entry.</li>
     *   <li>Phase {@code "test"}: the running statistics are copied through and used for
     *       normalization.</li>
     * </ul>
     * The output is {@code (image - mean) * invStd * scale + bias}.
     *
     * <p>Changes relative to the previous version:
     * <ul>
     *   <li>the running variance is now updated with the batch variance; it was previously
     *       updated with the batch mean;</li>
     *   <li>the fused normalization path is guarded by null checks on the dense arrays it
     *       indexes, instead of on the matrix objects, which had already been dereferenced.</li>
     * </ul>
     *
     * @param image                 input matrix
     * @param scale                 per-channel scale; converted to dense in place if sparse
     * @param bias                  per-channel bias; converted to dense in place if sparse
     * @param runningMean           running mean; converted to dense in place if sparse
     * @param runningVar            running variance; converted to dense in place if sparse
     * @param phase                 {@code "train"} or {@code "test"} (case-insensitive)
     * @param epsilon               constant added to the variance before the square root
     * @param mu                    weight of the previous running statistic
     * @param ret                   output: normalized result
     * @param retRunningMean        output: updated running mean
     * @param retRunningVar         output: updated running variance
     * @param resultSaveMean        output: mean used for normalization
     * @param resultSaveInvVariance output: inverse standard deviation used for normalization
     * @throws DMLRuntimeException if {@code phase} is neither train nor test
     */
    public static void batchNorm2D(MatrixBlock image, MatrixBlock scale, MatrixBlock bias, MatrixBlock runningMean, MatrixBlock runningVar, String phase, double epsilon, double mu, MatrixBlock ret, MatrixBlock retRunningMean, MatrixBlock retRunningVar, MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance) {
        if (bias.isInSparseFormat()) {
            bias.sparseToDense();
        }
        double[] biasArr = bias.getDenseBlockValues();
        if (scale.isInSparseFormat()) {
            scale.sparseToDense();
        }
        double[] scaleArr = scale.getDenseBlockValues();
        if (runningMean.isInSparseFormat()) {
            runningMean.sparseToDense();
        }
        double[] runningMeanArr = runningMean.getDenseBlockValues();
        if (runningVar.isInSparseFormat()) {
            runningVar.sparseToDense();
        }
        double[] runningVarArr = runningVar.getDenseBlockValues();
        double[] retRunningMeanArr = retRunningMean.getDenseBlockValues();
        double[] retRunningVarArr = retRunningVar.getDenseBlockValues();
        double[] resultSaveMeanArr = resultSaveMean.getDenseBlockValues();
        double[] resultSaveInvVarianceArr = resultSaveInvVariance.getDenseBlockValues();
        int N = image.getNumRows();
        int K = bias.getNumRows();
        int PQ = image.getNumColumns() / K;

        if (phase.equalsIgnoreCase("train")) {
            // The two result arrays temporarily hold the per-channel sum and sum of squares.
            computeBiasSumAndSumSquares(image, resultSaveMeanArr, resultSaveInvVarianceArr, K, PQ);
            int NPQ = N * PQ;
            for (int k = 0; k < K; ++k) {
                double mean = resultSaveMeanArr[k] / (double) NPQ;
                double var = resultSaveInvVarianceArr[k] / (double) NPQ - Math.pow(mean, 2.0);
                resultSaveMeanArr[k] = mean;
                resultSaveInvVarianceArr[k] = Math.pow(Math.sqrt(var + epsilon), -1.0);
                retRunningMeanArr[k] = mu * (runningMeanArr != null ? runningMeanArr[k] : 0.0) + (1.0 - mu) * mean;
                retRunningVarArr[k] = mu * (runningVarArr != null ? runningVarArr[k] : 0.0) + (1.0 - mu) * var;
            }
        } else if (phase.equalsIgnoreCase("test")) {
            copy(runningMean, retRunningMeanArr);
            copy(runningVar, retRunningVarArr);
            copy(runningMean, resultSaveMeanArr);
            double invSqrtEps = Math.pow(Math.sqrt(epsilon), -1.0);
            double[] inArr = runningVar.getDenseBlockValues();
            if (inArr != null) {
                for (int i = 0; i < inArr.length; ++i) {
                    resultSaveInvVarianceArr[i] = Math.pow(Math.sqrt(inArr[i] + epsilon), -1.0);
                }
            } else {
                // An empty running variance means a variance of zero for every channel.
                Arrays.fill(resultSaveInvVarianceArr, invSqrtEps);
            }
        } else {
            throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
        }

        double[] retArr = ret.getDenseBlockValues();
        copy(image, retArr);
        if (resultSaveMeanArr != null && resultSaveInvVarianceArr != null && biasArr != null && scaleArr != null) {
            // Fused path: one sweep over the output.
            int index = 0;
            for (int n = 0; n < N; ++n) {
                for (int k = 0; k < K; ++k) {
                    for (int pq = 0; pq < PQ; ++pq, ++index) {
                        retArr[index] = (retArr[index] - resultSaveMeanArr[k]) * resultSaveInvVarianceArr[k] * scaleArr[k] + biasArr[k];
                    }
                }
            }
        } else {
            // Null-tolerant path: addBias skips a null vector, multBias zeroes the output for one.
            addBias(retArr, resultSaveMeanArr, -1.0, N, K, PQ);
            multBias(retArr, resultSaveInvVarianceArr, N, K, PQ);
            multBias(retArr, scaleArr, N, K, PQ);
            addBias(retArr, biasArr, 1.0, N, K, PQ);
        }

        ret.recomputeNonZeros();
        retRunningMean.recomputeNonZeros();
        retRunningVar.recomputeNonZeros();
        resultSaveMean.recomputeNonZeros();
        resultSaveInvVariance.recomputeNonZeros();
    }

    /**
     * Copies the values of {@code input} into the row-major array {@code output}.
     *
     * <p>For a sparse input only the non-zero cells are written, so every other position of
     * {@code output} keeps its previous value. A dense input without an allocated value array
     * leaves {@code output} untouched.
     *
     * @param input  source matrix
     * @param output destination array, indexed as {@code row * numColumns + column}
     */
    private static void copy(MatrixBlock input, double[] output) {
        if (!input.isInSparseFormat()) {
            double[] dense = input.getDenseBlockValues();
            if (dense != null) {
                System.arraycopy(dense, 0, output, 0, dense.length);
            }
            return;
        }

        SparseBlock rows = input.getSparseBlock();
        int numCols = input.getNumColumns();
        for (int r = 0; r < input.getNumRows(); r++) {
            if (rows.isEmpty(r)) {
                continue;
            }
            int begin = rows.pos(r);
            int end = begin + rows.size(r);
            int[] colIdx = rows.indexes(r);
            double[] vals = rows.values(r);
            for (int j = begin; j < end; j++) {
                output[r * numCols + colIdx[j]] = vals[j];
            }
        }
    }

    /**
     * Adds {@code biasMultiplier * bias[k]} in place to every element of channel {@code k}.
     *
     * <p>{@code a} is read as {@code N} rows, each consisting of {@code K} consecutive groups of
     * {@code PQ} values. A {@code null} bias leaves {@code a} unchanged.
     *
     * @param a              array updated in place
     * @param bias           per-channel bias values, or {@code null}
     * @param biasMultiplier factor applied to each bias value before it is added
     * @param N              number of rows
     * @param K              number of channels
     * @param PQ             number of values per channel in each row
     */
    public static void addBias(double[] a, double[] bias, double biasMultiplier, int N, int K, int PQ) {
        if (bias == null) {
            return;
        }
        int pos = 0;
        for (int row = 0; row < N; row++) {
            for (int ch = 0; ch < K; ch++) {
                double delta = biasMultiplier * bias[ch];
                int end = pos + PQ;
                while (pos < end) {
                    a[pos] += delta;
                    pos++;
                }
            }
        }
    }

    /**
     * Multiplies every element of channel {@code k} in place by {@code bias[k]}.
     *
     * <p>{@code a} is read as {@code N} rows, each consisting of {@code K} consecutive groups of
     * {@code PQ} values. A {@code null} bias is treated as all zeros: the whole array is set
     * to {@code 0.0}.
     *
     * @param a    array updated in place
     * @param bias per-channel factors, or {@code null}
     * @param N    number of rows
     * @param K    number of channels
     * @param PQ   number of values per channel in each row
     */
    public static void multBias(double[] a, double[] bias, int N, int K, int PQ) {
        if (bias == null) {
            Arrays.fill(a, 0.0);
            return;
        }
        int pos = 0;
        for (int row = 0; row < N; row++) {
            for (int ch = 0; ch < K; ch++) {
                double factor = bias[ch];
                int end = pos + PQ;
                while (pos < end) {
                    a[pos] *= factor;
                    pos++;
                }
            }
        }
    }

    /**
     * Accumulates the per-channel sum and sum of squares of {@code image}.
     *
     * <p>Values are added onto the current contents of both arrays; nothing is reset here.
     * Column {@code j} of a row belongs to channel {@code j / PQ}. A dense image without an
     * allocated value array contributes nothing.
     *
     * @param image         input matrix (dense or sparse)
     * @param sumArr        accumulator for the per-channel sums; must have length {@code K}
     * @param sumSquaresArr accumulator for the per-channel sums of squares; must have length {@code K}
     * @param K             number of channels
     * @param PQ            number of values per channel in each row
     * @throws DMLRuntimeException if either accumulator does not have length {@code K}
     */
    private static void computeBiasSumAndSumSquares(MatrixBlock image, double[] sumArr, double[] sumSquaresArr, int K, int PQ) {
        if (sumArr.length != K) {
            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sumArr.length);
        }
        if (sumSquaresArr.length != K) {
            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sumSquaresArr.length);
        }

        if (image.isInSparseFormat()) {
            SparseBlock rows = image.getSparseBlock();
            for (int r = 0; r < image.getNumRows(); r++) {
                if (rows.isEmpty(r)) {
                    continue;
                }
                int begin = rows.pos(r);
                int end = begin + rows.size(r);
                int[] colIdx = rows.indexes(r);
                double[] vals = rows.values(r);
                for (int j = begin; j < end; j++) {
                    int ch = colIdx[j] / PQ;
                    sumArr[ch] += vals[j];
                    sumSquaresArr[ch] += Math.pow(vals[j], 2.0);
                }
            }
            return;
        }

        double[] dense = image.getDenseBlockValues();
        int numRows = image.getNumRows();
        if (dense == null) {
            return;
        }
        int pos = 0;
        for (int r = 0; r < numRows; r++) {
            for (int ch = 0; ch < K; ch++) {
                for (int off = 0; off < PQ; off++, pos++) {
                    sumArr[ch] += dense[pos];
                    sumSquaresArr[ch] += Math.pow(dense[pos], 2.0);
                }
            }
        }
    }

    /**
     * Multiplies {@code input} by a per-channel bias and writes the result to {@code outputBlock}.
     *
     * <p>With {@code K = bias.getNumRows()} and {@code PQ = input.getNumColumns() / K}, every
     * value in column group {@code k} of a row is multiplied by {@code bias[k]}. For a sparse
     * input, the column ranges of channels whose bias is zero are deleted from the output and
     * the remaining non-zeros are scaled in place. If the input or the bias is empty, only the
     * non-zero count of the output is set to 0.
     *
     * <p>Change relative to the previous version: the inputs are validated before dividing by
     * {@code K}, so a bias with zero rows raises the descriptive {@link DMLRuntimeException}
     * instead of an {@code ArithmeticException}.
     *
     * @param input       input matrix
     * @param bias        bias column vector ({@code K x 1}); converted to dense in place if sparse
     * @param outputBlock output matrix
     * @param numThreads  passed on to the parameter object; the computation here is single-threaded
     * @throws DMLRuntimeException if bias is not a non-empty column vector or the number of
     *                             input columns is not a multiple of {@code K}
     */
    public static void biasMultiply(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) {
        int N = input.getNumRows();
        int K = bias.getNumRows();
        if (K == 0 || bias.getNumColumns() != 1 || input.getNumColumns() % K != 0) {
            throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + N + " X " + input.getNumColumns() + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
        }
        int PQ = input.getNumColumns() / K;

        DnnParameters params = new DnnParameters(N, PQ, -1, -1, K, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = bias;
        params.output = outputBlock;

        if (input.isEmptyBlock() || bias.isEmptyBlock()) {
            params.output.setNonZeros(0L);
            return;
        }

        outputBlock.copy(input);
        if (bias.isInSparseFormat()) {
            bias.sparseToDense();
        }
        double[] biasArr = bias.getDenseBlockValues();

        if (!input.isInSparseFormat()) {
            double[] outputArray = outputBlock.getDenseBlockValues();
            int index = 0;
            for (int n = 0; n < N; ++n) {
                for (int k = 0; k < K; ++k) {
                    double biasVal = biasArr[k];
                    for (int pq = 0; pq < PQ; ++pq, ++index) {
                        outputArray[index] *= biasVal;
                    }
                }
            }
        } else {
            SparseBlock sblock = outputBlock.sparseBlock;
            // Pass 1: drop the column range of every channel whose bias is zero.
            for (int k = 0; k < K; ++k) {
                if (biasArr[k] != 0.0) {
                    continue;
                }
                for (int n = 0; n < N; ++n) {
                    if (sblock.isEmpty(n)) {
                        continue;
                    }
                    sblock.deleteIndexRange(n, k * PQ, (k + 1) * PQ);
                }
            }
            // Pass 2: scale the remaining non-zeros in place.
            for (int n = 0; n < N; ++n) {
                if (sblock.isEmpty(n)) {
                    continue;
                }
                int apos = sblock.pos(n);
                int alen = sblock.size(n);
                int[] aix = sblock.indexes(n);
                double[] avals = sblock.values(n);
                for (int j = apos; j < apos + alen; ++j) {
                    int k = aix[j] / PQ;
                    if (biasArr[k] == 0.0) {
                        continue;
                    }
                    avals[j] *= biasArr[k];
                }
            }
        }
        params.output.recomputeNonZeros();
        params.output.examSparsity();
    }

    /**
     * Runs the given tasks and returns the sum of their results.
     *
     * <p>With an effective parallelism of 1 the tasks run sequentially on the calling thread.
     * Otherwise they are submitted to a pool obtained from {@code CommonThreadPool} with
     * {@code min(k, params.N)} threads.
     *
     * <p>Changes relative to the previous version:
     * <ul>
     *   <li>the pool is shut down in a {@code finally} block, so it is also released when
     *       {@code invokeAll} or a task fails;</li>
     *   <li>the interrupt flag of the calling thread is restored when the wait is interrupted.</li>
     * </ul>
     *
     * @param tasks  tasks to run; each returns a partial count
     * @param params supplies the requested thread count ({@code numThreads}) and {@code N}
     * @return the sum of all task results
     * @throws DMLRuntimeException wrapping any exception raised while running the tasks
     */
    private static long execute(ArrayList<Callable<Long>> tasks, DnnParameters params) {
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        long lnnz = 0L;
        try {
            if (k == 1) {
                // Sequential path: no thread pool is created.
                for (Callable<Long> task : tasks) {
                    lnnz += task.call().longValue();
                }
            } else {
                ExecutorService pool = CommonThreadPool.get(Math.min(k, params.N));
                try {
                    // invokeAll returns only after every task has completed.
                    List<Future<Long>> taskret = pool.invokeAll(tasks);
                    for (Future<Long> task : taskret) {
                        lnnz += task.get().longValue();
                    }
                } finally {
                    pool.shutdown();
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interruption visible to code further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new DMLRuntimeException("Error while executing multi-threaded tasks", e);
        }
        return lnnz;
    }

    /**
     * Verifies that two values are equal.
     *
     * @param msg prefix of the error message
     * @param lhs actual value
     * @param rhs expected value
     * @throws DMLRuntimeException if {@code lhs} differs from {@code rhs}
     */
    private static void checkOrThrowException(String msg, long lhs, long rhs) {
        if (lhs == rhs) {
            return;
        }
        throw new DMLRuntimeException(msg + ":" + lhs + " != " + rhs);
    }

    /**
     * Verifies that a value equals the product of three factors.
     *
     * <p>Change relative to the previous version: the error message now closes the parenthesis
     * it opens around the product.
     *
     * @param msg  prefix of the error message
     * @param lhs  actual value
     * @param rhs1 first factor of the expected value
     * @param rhs2 second factor of the expected value
     * @param rhs3 third factor of the expected value
     * @throws DMLRuntimeException if {@code lhs != rhs1 * rhs2 * rhs3}
     */
    private static void checkOrThrowException(String msg, long lhs, long rhs1, long rhs2, long rhs3) {
        if (lhs != rhs1 * rhs2 * rhs3) {
            throw new DMLRuntimeException(msg + ":" + lhs + " != (" + rhs1 + " * " + rhs2 + " * " + rhs3 + ")");
        }
    }

    /**
     * Registers the conv2d backward-data operands on {@code params} and validates their shapes.
     *
     * <p>Expected shapes: filter {@code K x C*R*S}, dout {@code N x K*P*Q}. Both strides must be
     * positive. When {@code DMLScript.FINEGRAINED_STATISTICS} is set, the sparse or dense
     * invocation counter is incremented.
     *
     * <p>Change relative to the previous version: the filter-column error message now names
     * {@code filter_height*filter_width} (it previously repeated {@code filter_height}).
     *
     * @param filter      filter matrix
     * @param dout        errors from the next layer
     * @param outputBlock output matrix
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     * @throws DMLRuntimeException on a shape mismatch or a non-positive stride
     */
    static void checkInputsConv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = filter;
        params.input2 = dout;
        params.output = outputBlock;

        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_data: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), params.K, params.P, params.Q);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (filter.isInSparseFormat() || dout.isInSparseFormat()) {
                conv2dBwdDataSparseCount.addAndGet(1L);
            } else {
                conv2dBwdDataDenseCount.addAndGet(1L);
            }
        }
    }

    /**
     * Registers the conv2d backward-filter operands on {@code params} and validates their shapes.
     *
     * <p>Expected shapes: input {@code N x C*H*W}, dout {@code N x K*P*Q}. Both strides must be
     * positive. When {@code DMLScript.FINEGRAINED_STATISTICS} is set, the sparse or dense
     * invocation counter is incremented.
     *
     * <p>Change relative to the previous version: the input-column error message now names
     * {@code input_height*input_width} (it previously repeated {@code input_height}).
     *
     * @param input       input matrix
     * @param dout        errors from the next layer
     * @param outputBlock output matrix
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     * @throws DMLRuntimeException on a shape mismatch or a non-positive stride
     */
    static void checkInputsConv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;

        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input data != batch size in input_shape", input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), params.C, params.H, params.W);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of rows of input errors != batch size in input_shape", dout.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d_backward_filter: Number of columns of input errors != expected input error channels*height*width", dout.getNumColumns(), params.K, params.P, params.Q);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (input.isInSparseFormat() || dout.isInSparseFormat()) {
                conv2dBwdFilterSparseCount.addAndGet(1L);
            } else {
                conv2dBwdFilterDenseCount.addAndGet(1L);
            }
        }
    }

    /**
     * Registers the conv2d operands on {@code params} and validates their shapes.
     *
     * <p>Expected shapes: filter {@code K x C*R*S}, input {@code N x C*H*W}. Both strides must be
     * positive. When {@code DMLScript.FINEGRAINED_STATISTICS} is set, the sparse or dense
     * invocation counter is incremented.
     *
     * <p>Change relative to the previous version: the column error messages now name
     * {@code filter_height*filter_width} and {@code input_height*input_width} (they previously
     * repeated the height).
     *
     * @param input       input matrix
     * @param filter      filter matrix
     * @param outputBlock output matrix
     * @param params      convolution parameters (mutated: input1, input2, output are assigned)
     * @throws DMLRuntimeException on a shape mismatch or a non-positive stride
     */
    static void checkInputsConv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        params.input1 = input;
        params.input2 = filter;
        params.output = outputBlock;

        checkOrThrowException("Incorrect input to conv2d: Number of rows of input filter != number of filters in filter_shape", filter.getNumRows(), params.K);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input filter != channels*filter_height*filter_width in filter_shape", filter.getNumColumns(), params.C, params.R, params.S);
        checkOrThrowException("Incorrect input to conv2d: Number of rows of input data != batch size in input_shape", input.getNumRows(), params.N);
        checkOrThrowException("Incorrect input to conv2d: Number of columns of input data != channels*input_height*input_width in input_shape", input.getNumColumns(), params.C, params.H, params.W);
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported:" + params.stride_h + ", " + params.stride_w);
        }

        if (DMLScript.FINEGRAINED_STATISTICS) {
            if (input.isInSparseFormat() || filter.isInSparseFormat()) {
                conv2dSparseCount.addAndGet(1L);
            } else {
                conv2dDenseCount.addAndGet(1L);
            }
        }
    }

    /**
     * Precomputes, for every output row {@code p} and output column {@code q}, the clipped
     * input range covered by the corresponding window.
     *
     * <p>The window of output position {@code i} starts at {@code i * stride - pad} and spans
     * {@code R} rows (respectively {@code S} columns). The start is clamped to 0 and the end to
     * {@code H} (respectively {@code W}). The four arrays on {@code params} are newly allocated.
     *
     * @param params parameter object whose start/end index arrays are (re)assigned
     */
    private static void fillIndexesArray(DnnParameters params) {
        params.start_indexes_h = new int[params.P];
        params.end_indexes_h = new int[params.P];
        params.start_indexes_w = new int[params.Q];
        params.end_indexes_w = new int[params.Q];

        // Vertical extent of each window.
        for (int p = 0, top = -params.pad_h; p < params.P; p++, top += params.stride_h) {
            params.start_indexes_h[p] = Math.max(top, 0);
            params.end_indexes_h[p] = Math.min(top + params.R, params.H);
        }

        // Horizontal extent of each window.
        for (int q = 0, left = -params.pad_w; q < params.Q; q++, left += params.stride_w) {
            params.start_indexes_w[q] = Math.max(left, 0);
            params.end_indexes_w[q] = Math.min(left + params.S, params.W);
        }
    }

    /** Pooling variants accepted by {@code pooling} and {@code poolingBackward}. */
    public enum PoolingType {
        /** Max pooling. */
        MAX,
        /** Average pooling. */
        AVG
    }
}

