/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.utils;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Script;

/**
 * MLContext wrapper around the DML script {@code scripts/utils/metrics.dml}.
 *
 * <p>The no-arg constructor loads the whole script text from the classpath and installs it as
 * this script's source. The remaining methods expose the script's {@code classification_report}
 * function: one executes it, and two return its DML source text.
 */
public class Metrics
extends Script {
    /** Classpath-relative location of the wrapped DML script. */
    private static final String SCRIPT_PATH = "scripts/utils/metrics.dml";

    /** Size of the character buffer used when copying the script resource into memory. */
    private static final int READ_BUFFER_SIZE = 1024;

    /**
     * DML source of the {@code classification_report} function. It is shared by
     * {@link #classification_report__docs()} and {@link #classification_report__source()},
     * which previously each carried an identical copy of this literal.
     */
    private static final String CLASSIFICATION_REPORT_SOURCE =
            "classification_report = function(matrix[double] y_true, matrix[double] y_pred, matrix[double] labels) return (string out) {\n"
            + "\tnum_rows_error_measures = nrow(labels)\n"
            + "\terror_measures = matrix(0, rows=num_rows_error_measures, cols=5)\n"
            + "\tfor(i in 1:num_rows_error_measures) {\n"
            + "\t\tclass_i = labels[i,1]\n"
            + "        tp = sum( (y_true == y_pred) * (y_true == class_i) )\n"
            + "        tp_plus_fp = sum( (y_pred == class_i) )\n"
            + "        tp_plus_fn = sum( (y_true == class_i) )\n"
            + "        precision = tp / tp_plus_fp\n"
            + "        recall = tp / tp_plus_fn\n"
            + "        f1Score = 2*precision*recall / (precision+recall)\n"
            + "        error_measures[i,1] = class_i\n"
            + "        error_measures[i,2] = precision\n"
            + "        error_measures[i,3] = recall\n"
            + "        error_measures[i,4] = f1Score\n"
            + "        error_measures[i,5] = tp_plus_fn\n"
            + "\t}\n"
            + "\t# Added num_true_labels to debug whether the input data was randomized or now, which is common requirement of SGD-style algorithms.\n"
            + "\t# Also, helps debug class-skew related problems.\n"
            + "\tout = \"class    \\tprecision\\trecall  \\tf1-score\\tnum_true_labels\\n\" + toString(error_measures, decimal=7, sep=\"\\t\")\n"
            + "}\n";

    /**
     * Creates the wrapper and installs the full text of {@code /scripts/utils/metrics.dml}
     * (resolved via {@code Script.class}'s class loader) as this script's source.
     *
     * @throws NullPointerException if the script resource cannot be found on the classpath
     * @throws UncheckedIOException if reading the script resource fails
     */
    public Metrics() {
        this.setScriptString(readResource(SCRIPT_PATH));
    }

    /**
     * Reads a classpath resource fully into a string, decoding it as UTF-8.
     *
     * @param path classpath-relative resource path (without a leading slash)
     * @return the complete resource contents
     * @throws NullPointerException if the resource does not exist
     * @throws UncheckedIOException if reading fails; the cause is preserved
     */
    private static String readResource(String path) {
        String absolutePath = "/" + path;
        InputStream in = Script.class.getResourceAsStream(absolutePath);
        // Fail with a descriptive message instead of a bare NPE from the reader constructor.
        Objects.requireNonNull(in, "DML script resource not found on classpath: " + absolutePath);
        StringBuilder contents = new StringBuilder();
        char[] buffer = new char[READ_BUFFER_SIZE];
        // try-with-resources closes the reader and, through it, the resource stream.
        try (InputStreamReader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            int n;
            while ((n = reader.read(buffer)) != -1) {
                contents.append(buffer, 0, n);
            }
        } catch (IOException e) {
            // Installing a partially read script would only fail later and obscurely; surface it now.
            throw new UncheckedIOException("Failed to read DML script resource: " + absolutePath, e);
        }
        return contents.toString();
    }

    /**
     * Executes the DML {@code classification_report} function from {@code metrics.dml} in a
     * fresh {@link Script} and returns its {@code out} result.
     *
     * @param object value bound to the DML input {@code y_true}
     * @param object2 value bound to the DML input {@code y_pred}
     * @param object3 value bound to the DML input {@code labels}
     * @return the string produced by the DML function's {@code out} variable
     */
    public String classification_report(Object object, Object object2, Object object3) {
        String dml = "source('" + SCRIPT_PATH + "') as mlcontextns;"
                + "out = mlcontextns::classification_report(y_true, y_pred, labels);";
        Script script = new Script(dml);
        script.in("y_true", object).in("y_pred", object2).in("labels", object3).out("out");
        MLResults results = script.execute();
        return results.getString("out");
    }

    /**
     * Returns the DML source text of {@code classification_report}, used as its documentation.
     *
     * @return the function's DML source
     */
    public String classification_report__docs() {
        return CLASSIFICATION_REPORT_SOURCE;
    }

    /**
     * Returns the DML source text of {@code classification_report}.
     *
     * @return the function's DML source
     */
    public String classification_report__source() {
        return CLASSIFICATION_REPORT_SOURCE;
    }
}

