/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.rexp.GLMNetConverter;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.S4Object;

/**
 * Converter for a glmnet multinomial fit (the "multnet" object passed to the constructor)
 * that encodes one solution of the fit as a PMML classification {@link Model}.
 *
 * <p>NOTE(review): the R-side interpretation of the object is inferred from the class and
 * constructor argument names; parsing of the R object itself is done by {@link GLMNetConverter}.
 */
public class MultNetConverter extends GLMNetConverter {

    public MultNetConverter(RGenericVector multnet) {
        super(multnet);
    }

    /**
     * Encodes one column of the fitted intercept/coefficient path as a PMML classification model.
     *
     * <p>With exactly two target categories, a binary logistic regression model with
     * {@code LOGIT} normalization is produced. Its intercept is taken from row 1 of {@code a0},
     * and its coefficients from the second element of {@code beta}; both are multiplied by 2.
     *
     * <p>With more than two target categories, a {@link RegressionModel} is produced. It has
     * {@code CLASSIFICATION} mining function, {@code SOFTMAX} normalization and a {@code DOUBLE}
     * probability output. It holds one {@link RegressionTable} per category: the intercept comes
     * from row {@code i} of {@code a0}, and the coefficients come from the {@code beta} element
     * named after that category.
     *
     * @param a0 intercept matrix (read via its {@code dim()} as rows x columns); row {@code i}
     *     holds the intercepts of category {@code i}
     * @param beta per-category coefficient objects; must be an {@link RGenericVector} whose
     *     elements are {@link S4Object}s
     * @param column index selecting one entry of each intercept row and one coefficient column
     *     (presumably the lambda index along the regularization path — TODO confirm)
     * @param schema model schema; its label must be a {@link CategoricalLabel}
     * @return the encoded PMML model
     * @throws IllegalArgumentException if the label has fewer than two categories
     * @throws ClassCastException if the label is not categorical, or {@code beta} is not a list
     */
    @Override
    public Model encodeModel(RDoubleVector a0, RExp beta, int column, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        RIntegerVector a0Dim = a0.dim();

        int a0Rows = a0Dim.getValue(0);
        int a0Columns = a0Dim.getValue(1);

        RGenericVector categoryBetas = (RGenericVector)beta;

        int size = categoricalLabel.size();

        if (size == 2) {
            // NOTE(review): this branch selects beta by position (index 1) while the multiclass
            // branch selects it by category name; both assume the list order matches the label order.
            List<Double> categoryA0 = FortranMatrixUtil.getRow(a0.getValues(), a0Rows, a0Columns, 1);

            S4Object categoryBeta = (S4Object)categoryBetas.getValue(1);

            // Presumably relies on glmnet's symmetric two-class parameterization (eta_1 = -eta_2),
            // so the log-odds eta_2 - eta_1 equals 2 * eta_2 — TODO confirm against glmnet docs.
            Function<Double, Double> doubler = new Function<Double, Double>() {

                @Override
                public Double apply(Double value) {
                    return 2d * value;
                }
            };

            Double intercept = doubler.apply(categoryA0.get(column));
            List<Double> coefficients = Lists.transform(getCoefficients(categoryBeta, column), doubler);

            return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), coefficients, intercept, RegressionModel.NormalizationMethod.LOGIT, true, schema);
        }

        if (size > 2) {
            RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

            for (int i = 0; i < size; i++) {
                Object targetCategory = categoricalLabel.getValue(i);

                List<Double> categoryA0 = FortranMatrixUtil.getRow(a0.getValues(), a0Rows, a0Columns, i);

                S4Object categoryBeta = (S4Object)categoryBetas.getElement(ValueUtil.asString(targetCategory));

                Double intercept = categoryA0.get(column);
                List<Double> coefficients = getCoefficients(categoryBeta, column);

                RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(schema.getFeatures(), coefficients, intercept)
                    .setTargetCategory(targetCategory);

                regressionModel.addRegressionTables(regressionTable);
            }

            return regressionModel;
        }

        throw new IllegalArgumentException("Expected 2 or more target categories, got " + size);
    }
}

