/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.core.Measurement;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.IntOption;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class TemporallyAugmentedClassifier
extends AbstractClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
    public IntOption numOldLabelsOption = new IntOption("numOldLabels", 'n', "The number of old labels to add to each example.", 1, 0, Integer.MAX_VALUE);
    protected Classifier baseLearner;
    protected double[] oldLabels;
    protected Instances header;
    public FlagOption labelDelayOption = new FlagOption("labelDelay", 'd', "Labels arrive with Delay. Use predictions instead of true Labels.");

    public String getPurposeString() {
        return "Add some old labels to every instance";
    }

    public void resetLearningImpl() {
        this.baseLearner = (Classifier)this.getPreparedClassOption(this.baseLearnerOption);
        this.oldLabels = new double[this.numOldLabelsOption.getValue()];
        this.header = null;
        this.baseLearner.resetLearning();
    }

    public void trainOnInstanceImpl(Instance instance) {
        this.baseLearner.trainOnInstance(this.extendWithOldLabels(instance));
        if (!this.labelDelayOption.isSet()) {
            this.addOldLabel(instance.classValue());
        }
    }

    public void addOldLabel(double newPrediction) {
        int numLabels = this.oldLabels.length;
        if (numLabels > 0) {
            for (int i = 1; i < numLabels; ++i) {
                this.oldLabels[i - 1] = this.oldLabels[i];
            }
            this.oldLabels[numLabels - 1] = newPrediction;
        }
    }

    public void initHeader(Instances dataset) {
        int i;
        int numLabels = this.numOldLabelsOption.getValue();
        Attribute target = dataset.classAttribute();
        ArrayList<String> possibleValues = new ArrayList<String>();
        int n = target.numValues();
        for (int i2 = 0; i2 < n; ++i2) {
            possibleValues.add(target.value(i2));
        }
        ArrayList<Attribute> attrs = new ArrayList<Attribute>(numLabels + dataset.numAttributes());
        for (i = 0; i < numLabels; ++i) {
            attrs.add(new Attribute(target.name() + "_" + i, possibleValues));
        }
        for (i = 0; i < dataset.numAttributes(); ++i) {
            attrs.add((Attribute)dataset.attribute(i).copy());
        }
        this.header = new Instances("extended_" + dataset.relationName(), attrs, 0);
        this.header.setClassIndex(numLabels + dataset.classIndex());
    }

    public Instance extendWithOldLabels(Instance instance) {
        int numLabels;
        if (this.header == null) {
            this.initHeader(instance.dataset());
        }
        if ((numLabels = this.oldLabels.length) == 0) {
            return instance;
        }
        double[] x = instance.toDoubleArray();
        double[] x2 = Arrays.copyOfRange(this.oldLabels, 0, numLabels + x.length);
        System.arraycopy(x, 0, x2, numLabels, x.length);
        DenseInstance extendedInstance = new DenseInstance(instance.weight(), x2);
        extendedInstance.setDataset(this.header);
        return extendedInstance;
    }

    public double[] getVotesForInstance(Instance instance) {
        double[] prediction = this.baseLearner.getVotesForInstance(this.extendWithOldLabels(instance));
        if (this.labelDelayOption.isSet()) {
            this.addOldLabel(Utils.maxIndex(prediction));
        }
        return prediction;
    }

    public boolean isRandomizable() {
        return false;
    }

    protected Measurement[] getModelMeasurementsImpl() {
        LinkedList<Measurement> measurementList = new LinkedList<Measurement>();
        Measurement[] modelMeasurements = ((AbstractClassifier)this.baseLearner).getModelMeasurements();
        if (modelMeasurements != null) {
            for (Measurement measurement : modelMeasurements) {
                measurementList.add(measurement);
            }
        }
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }

    public void getModelDescription(StringBuilder out, int indent) {
    }

    public String toString() {
        return "TemporallyAugmentedClassifier using " + this.numOldLabelsOption.getValue() + " labels\n" + this.baseLearner;
    }
}

