/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.FlagManager;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVectorUtil;
import org.jpmml.rexp.TreeModelConverter;

/**
 * Converts an R {@code gbm} model object into a PMML {@link MiningModel}.
 *
 * <p>The {@code distribution$name} element selects the model type:
 * {@code "gaussian"} is encoded as regression, {@code "adaboost"} and
 * {@code "bernoulli"} as binary classification over the integer classes 0 and 1,
 * and {@code "multinomial"} as multi-class classification. Every boosted tree
 * becomes a regression {@link TreeModel} segment; segment predictions are summed
 * and shifted by {@code initF}.
 */
public class GBMConverter
extends TreeModelConverter<RGenericVector> {
    /** Category values of the target field for the binary ("adaboost", "bernoulli") distributions. */
    private static final List<Integer> BINARY_CLASSES = Arrays.asList(0, 1);

    public GBMConverter(RGenericVector gbm) {
        super(gbm);
    }

    /**
     * Registers the target field and one field per predictor with the encoder.
     *
     * <p>Predictors whose {@code var.type} entry is positive are declared as
     * categorical string fields using the levels from {@code var.levels};
     * all other predictors are declared as continuous doubles.
     *
     * @param encoder the encoder that collects the data fields
     * @throws IllegalArgumentException if the distribution is not supported, or
     *         if a multinomial model lacks the {@code classes} element
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector gbm = this.getObject();

        RGenericVector distribution = gbm.getGenericElement("distribution");
        RStringVector response_name = gbm.getStringElement("response.name", false);
        RGenericVector var_levels = gbm.getGenericElement("var.levels");
        RStringVector var_names = gbm.getStringElement("var.names");
        RNumberVector<?> var_type = gbm.getNumericElement("var.type");
        RStringVector classes = gbm.getStringElement("classes", false);

        String distributionName = distribution.getStringElement("name").asScalar();

        RVectorUtil.checkSize(var_names, var_type);

        // "response.name" is optional; when it is absent the target is named "y"
        String responseName = (response_name != null) ? response_name.asScalar() : "y";

        DataField dataField;
        switch (distributionName) {
            case "gaussian":
                dataField = encoder.createDataField(responseName, OpType.CONTINUOUS, DataType.DOUBLE);
                break;
            case "adaboost":
            case "bernoulli":
                dataField = encoder.createDataField(responseName, OpType.CATEGORICAL, DataType.INTEGER, BINARY_CLASSES);
                break;
            case "multinomial":
                // "classes" is fetched as optional above, but multinomial models cannot be encoded without it
                if (classes == null) {
                    throw new IllegalArgumentException("Distribution 'multinomial' requires the 'classes' element");
                }
                dataField = encoder.createDataField(responseName, OpType.CATEGORICAL, DataType.STRING, classes.getValues());
                break;
            default:
                throw new IllegalArgumentException("Distribution '" + distributionName + "' is not supported");
        }
        encoder.setLabel(dataField);

        for (int i = 0; i < var_names.size(); i++) {
            String varName = var_names.getValue(i);

            // A positive var.type entry marks a categorical predictor
            boolean categorical = ValueUtil.asInt(var_type.getValue(i)) > 0;

            DataField featureField;
            if (categorical) {
                RStringVector var_level = var_levels.getStringValue(i);

                featureField = encoder.createDataField(varName, OpType.CATEGORICAL, DataType.STRING, var_level.getValues());
            } else {
                featureField = encoder.createDataField(varName, OpType.CONTINUOUS, DataType.DOUBLE);
            }
            encoder.addFeature(featureField);
        }
    }

    /**
     * Encodes every tree of the {@code trees} element as a regression
     * {@link TreeModel} and combines them according to the distribution.
     *
     * @param schema the schema built by {@link #encodeSchema(RExpEncoder)}
     * @return the combined mining model
     * @throws IllegalArgumentException if the distribution is not supported
     */
    @Override
    public MiningModel encodeModel(Schema schema) {
        RGenericVector gbm = this.getObject();

        RDoubleVector initF = gbm.getDoubleElement("initF");
        RGenericVector trees = gbm.getGenericElement("trees");
        RGenericVector c_splits = gbm.getGenericElement("c.splits");
        RGenericVector distribution = gbm.getGenericElement("distribution");

        RStringVector distributionName = distribution.getStringElement("name");

        // Individual trees predict an anonymous continuous value
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < trees.size(); i++) {
            RGenericVector tree = trees.getGenericValue(i);

            treeModels.add(this.encodeTreeModel(MiningFunction.REGRESSION, tree, c_splits, segmentSchema));
        }

        return this.encodeMiningModel(distributionName, treeModels, initF.asScalar(), schema);
    }

    /**
     * Dispatches on the distribution name to build the final model.
     *
     * @throws IllegalArgumentException if the distribution is not supported
     */
    private MiningModel encodeMiningModel(RStringVector distributionName, List<TreeModel> treeModels, Double initF, Schema schema) {
        String name = distributionName.asScalar();

        switch (name) {
            case "gaussian":
                return this.encodeRegression(treeModels, initF, schema);
            case "adaboost":
                return this.encodeBinaryClassification(treeModels, initF, -2.0, schema);
            case "bernoulli":
                return this.encodeBinaryClassification(treeModels, initF, -1.0, schema);
            case "multinomial":
                return this.encodeMultinomialClassification(treeModels, initF, schema);
            default:
                throw new IllegalArgumentException("Distribution '" + name + "' is not supported");
        }
    }

    /** Sums all tree predictions and adds {@code initF} as the rescale constant. */
    private MiningModel encodeRegression(List<TreeModel> treeModels, Double initF, Schema schema) {
        return GBMConverter.createMiningModel(treeModels, initF, schema);
    }

    /**
     * Builds a binary classifier: the summed score ("gbmValue") is multiplied by
     * {@code -coefficient} and passed through the logit function.
     *
     * @param coefficient the callers pass -2.0 for "adaboost" and -1.0 for "bernoulli"
     */
    private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        MiningModel miningModel = GBMConverter.createMiningModel(treeModels, initF, segmentSchema)
            .setOutput(ModelUtil.createPredictedOutput("gbmValue", OpType.CONTINUOUS, DataType.DOUBLE));

        return MiningModelUtil.createBinaryLogisticClassification(miningModel, -coefficient, 0.0, RegressionModel.NormalizationMethod.LOGIT, true, schema);
    }

    /**
     * Builds a multi-class classifier with one summed-score sub-model per target
     * category, combined with softmax.
     *
     * <p>Sub-model {@code i} uses column {@code i} of the
     * {@code rows x columns} tree matrix, as extracted by
     * {@link CMatrixUtil#getColumn}, where {@code columns} is the number of
     * categories.
     *
     * @throws IllegalArgumentException if the tree count is not a multiple of the category count
     */
    private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

        int columns = categoricalLabel.size();
        // Guard against silently dropping trailing trees (or dividing by zero)
        if (columns == 0 || treeModels.size() % columns != 0) {
            throw new IllegalArgumentException("Cannot split " + treeModels.size() + " trees evenly across " + columns + " classes");
        }
        int rows = treeModels.size() / columns;

        List<MiningModel> miningModels = new ArrayList<>(columns);
        for (int i = 0; i < columns; i++) {
            List<TreeModel> classTreeModels = CMatrixUtil.getColumn(treeModels, rows, columns, i);

            MiningModel miningModel = GBMConverter.createMiningModel(classTreeModels, initF, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("gbmValue", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE));

            miningModels.add(miningModel);
        }

        return MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
    }

    /** Encodes a single gbm tree, starting from row 0 with a {@link True} predicate. */
    private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema) {
        Node root = this.encodeNode(0, True.INSTANCE, tree, c_splits, new FlagManager(), new CategoryManager(), schema);

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
    }

    /**
     * Recursively encodes row {@code i} of a gbm tree and its descendants.
     *
     * <p>The tree is read column-wise: element 0 is the split variable, 1 the
     * split code or prediction, 2/3/4 the left/right/missing child rows, and 7
     * the node prediction. A split variable of -1 marks a leaf. PMML node ids
     * are {@code i + 1}.
     *
     * <p>Continuous splits send {@code x < split} left and {@code x >= split}
     * right. For categorical splits the split code is an index into
     * {@code c.splits}; entries equal to -1 go left and entries equal to 1 go right.
     *
     * <p>{@code flagManager} records, per feature, whether the current path has
     * already established that the value is missing. Once that is known, only
     * the children consistent with it are emitted.
     *
     * @param i zero-based row index of the node
     * @param predicate predicate guarding this node
     * @return the encoded node
     */
    private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
        Integer id = i + 1;

        RIntegerVector splitVar = tree.getIntegerValue(0);
        RDoubleVector splitCodePred = tree.getDoubleValue(1);
        RIntegerVector leftNode = tree.getIntegerValue(2);
        RIntegerVector rightNode = tree.getIntegerValue(3);
        RIntegerVector missingNode = tree.getIntegerValue(4);
        RDoubleVector prediction = tree.getDoubleValue(7);

        Integer var = splitVar.getValue(i);
        if (var == -1) {
            Double value = prediction.getValue(i);

            return new LeafNode(value, predicate).setId(id);
        }

        Feature feature = schema.getFeature(var);
        String name = feature.getName();

        // null: missingness of this feature is still undecided on the current path
        Boolean isMissing = flagManager.getValue(name);

        FlagManager missingFlagManager = flagManager;
        FlagManager nonMissingFlagManager = flagManager;
        if (isMissing == null) {
            missingFlagManager = flagManager.fork(name, Boolean.TRUE);
            nonMissingFlagManager = flagManager.fork(name, Boolean.FALSE);
        }

        Predicate missingPredicate = this.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        Double split = splitCodePred.getValue(i);
        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            List<?> values = categoricalFeature.getValues();

            RIntegerVector c_split = c_splits.getIntegerValue(ValueUtil.asInt(split));
            List<Integer> splitValues = c_split.getValues();

            // Filter derived from the category sets that ancestor splits forked into the manager
            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = GBMConverter.selectValues(values, valueFilter, splitValues, true);
            List<Object> rightValues = GBMConverter.selectValues(values, valueFilter, splitValues, false);

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = this.createPredicate(categoricalFeature, leftValues);
            rightPredicate = this.createPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            leftPredicate = this.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
            rightPredicate = this.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
        }

        Node result = new BranchNode(null, predicate).setId(id);
        List<Node> nodes = result.getNodes();

        boolean mayBeMissing = (isMissing == null || isMissing);
        boolean mayBePresent = (isMissing == null || !isMissing);

        // The missing child is emitted first so that the IS_MISSING predicate is evaluated before the value tests
        Integer missing = missingNode.getValue(i);
        if (missing != -1 && mayBeMissing) {
            nodes.add(this.encodeNode(missing, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema));
        }

        Integer left = leftNode.getValue(i);
        if (left != -1 && mayBePresent) {
            nodes.add(this.encodeNode(left, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema));
        }

        Integer right = rightNode.getValue(i);
        if (right != -1 && mayBePresent) {
            nodes.add(this.encodeNode(right, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema));
        }

        return result;
    }

    /**
     * Wraps the trees in a regression {@link MiningModel} that sums the segment
     * predictions (returning missing if any segment is missing) and adds
     * {@code initF} as the rescale constant.
     */
    private static MiningModel createMiningModel(List<TreeModel> treeModels, Double initF, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();

        Segmentation segmentation = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels);

        return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
            .setSegmentation(segmentation)
            .setTargets(ModelUtil.createRescaleTargets(null, initF, continuousLabel));
    }

    /**
     * Selects the category values that a categorical split sends to one side.
     *
     * @param values the feature's category values
     * @param valueFilter additional filter a value must pass to be selected
     * @param splitValues split codes aligned with {@code values}; -1 selects a value
     *        for the left side and 1 for the right side
     * @param left {@code true} to select the left side, {@code false} for the right side
     * @return the selected values, in their original order
     * @throws IllegalArgumentException if {@code values} and {@code splitValues} differ in size
     */
    private static List<Object> selectValues(List<?> values, java.util.function.Predicate<Object> valueFilter, List<Integer> splitValues, boolean left) {
        if (values.size() != splitValues.size()) {
            throw new IllegalArgumentException("Expected " + values.size() + " split codes, got " + splitValues.size());
        }

        int selector = left ? -1 : 1;

        List<Object> result = new ArrayList<>();
        for (int i = 0; i < values.size(); i++) {
            Object value = values.get(i);
            Integer splitValue = splitValues.get(i);

            if (splitValue == selector && valueFilter.test(value)) {
                result.add(value);
            }
        }
        return result;
    }
}

