/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.FloatOption;
import moa.options.IntOption;
import weka.core.Instance;

/**
 * Boosting ensemble for evolving data streams that pairs every member with an ADWIN
 * change detector watching that member's 0/1 training error.
 *
 * <p>Each training instance is passed along the ensemble with a boosting weight
 * {@code lambda_d}: members that classify it correctly shrink the weight handed to the
 * next member, and members that misclassify it grow it. When any ADWIN detector signals
 * that its error estimate went up, the member with the largest estimated error is reset.
 *
 * <p>Optional modes:
 * <ul>
 *   <li>{@code pureBoost}: train with weight {@code lambda_d} instead of a Poisson draw;</li>
 *   <li>{@code outputCodes}: each member learns a random binary relabelling of the classes,
 *       and predictions are decoded through that code matrix;</li>
 *   <li>{@code same} (SAMME): with K classes, scale the Poisson rate by K-1 and add
 *       log(K-1) to every member weight.</li>
 * </ul>
 */
public class OzaBoostAdwin extends AbstractClassifier {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);
    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson.");
    public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a', "Delta of Adwin change detection", 0.002, 0.0, 1.0);
    public FlagOption outputCodesOption = new FlagOption("outputCodes", 'o', "Use Output Codes to use binary classifiers.");
    // The option name "same" (for SAMME) is kept as-is so existing command lines keep working.
    public FlagOption sammeOption = new FlagOption("same", 'e', "Use Samme Algorithm.");
    /** Ensemble members. */
    protected Classifier[] ensemble;
    /** Per-member accumulated boosting weight of correctly classified training instances. */
    protected double[] scms;
    /** Per-member accumulated boosting weight of misclassified training instances. */
    protected double[] swms;
    /** Per-member ADWIN detector fed with 0 (correct) / 1 (wrong). */
    protected ADWIN[] ADError;
    /** Number of times a change triggered a member reset attempt. */
    protected int numberOfChangesDetected;
    /** Output-code matrix: row = ensemble member, column = original class, value = 0/1 label. */
    protected int[][] matrixCodes;
    /** True until the output-code matrix has been drawn (on the first training instance). */
    protected boolean initMatrixCodes = false;
    /** SAMME additive weight term log(K-1); 0 when SAMME is off. */
    protected double logKm1 = 0.0;
    /** SAMME Poisson-rate factor K-1; 1 when SAMME is off. */
    protected int Km1 = 1;
    /** True until the SAMME factors have been set from the first training instance. */
    protected boolean initKm1 = false;

    public String getPurposeString() {
        return "Boosting for evolving data streams using ADWIN.";
    }

    /**
     * Rebuilds the ensemble from fresh copies of the base learner, clears all boosting
     * statistics and detectors, and re-arms the lazy SAMME / output-code initialisation.
     */
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[size];
        Classifier baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        for (int i = 0; i < size; ++i) {
            this.ensemble[i] = baseLearner.copy();
        }
        this.scms = new double[size];
        this.swms = new double[size];
        this.ADError = new ADWIN[size];
        for (int i = 0; i < size; ++i) {
            this.ADError[i] = new ADWIN(this.deltaAdwinOption.getValue());
        }
        this.numberOfChangesDetected = 0;
        // Clear any SAMME factors left over from a previous configuration; they are
        // recomputed from the data on the first training instance when SAMME is on.
        this.Km1 = 1;
        this.logKm1 = 0.0;
        this.initKm1 = this.sammeOption.isSet();
        this.initMatrixCodes = this.outputCodesOption.isSet();
    }

    /**
     * Trains every member on {@code inst} with its boosting weight, updates the error
     * statistics and ADWIN detectors, and resets the worst member if a change is detected.
     *
     * @param inst the labelled training instance
     */
    public void trainOnInstanceImpl(Instance inst) {
        int numClasses = inst.numClasses();
        if (this.sammeOption.isSet()) {
            this.Km1 = numClasses - 1;
            this.logKm1 = Math.log(this.Km1);
            this.initKm1 = false;
        }
        if (this.initMatrixCodes) {
            this.matrixCodes = this.buildOutputCodeMatrix(numClasses);
            this.initMatrixCodes = false;
        }
        boolean changeDetected = false;
        double lambda_d = 1.0;
        Instance weightedInst = (Instance) inst.copy();
        for (int i = 0; i < this.ensemble.length; ++i) {
            if (this.outputCodesOption.isSet()) {
                // Relabel for member i even when it is not trained this round, so the
                // correctness check below uses member i's code, not a previous member's.
                weightedInst.setClassValue(this.matrixCodes[i][(int) inst.classValue()]);
            }
            double k = this.pureBoostOption.isSet()
                    ? lambda_d
                    : (double) MiscUtils.poisson(lambda_d * (double) this.Km1, this.classifierRandom);
            if (k > 0.0) {
                weightedInst.setWeight(inst.weight() * k);
                this.ensemble[i].trainOnInstance(weightedInst);
            }
            boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst);
            // The updated lambda_d is what the next member is trained with.
            if (correctlyClassifies) {
                this.scms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2.0 * this.scms[i]);
            } else {
                this.swms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2.0 * this.swms[i]);
            }
            double errEstim = this.ADError[i].getEstimation();
            // Only a detected change that *raised* the error estimate counts as drift.
            if (this.ADError[i].setInput(correctlyClassifies ? 0.0 : 1.0)
                    && this.ADError[i].getEstimation() > errEstim) {
                changeDetected = true;
            }
        }
        if (changeDetected) {
            ++this.numberOfChangesDetected;
            this.resetWorstMember();
        }
    }

    /**
     * Draws a random 0/1 output-code matrix with one row per member and one column per class.
     * Each row is redrawn until it is as balanced as possible: equal numbers of 0s and 1s
     * when {@code numClasses} is even, differing by exactly one when it is odd. With two
     * classes the second column is always the complement of the first.
     */
    private int[][] buildOutputCodeMatrix(int numClasses) {
        int[][] codes = new int[this.ensemble.length][numClasses];
        for (int i = 0; i < this.ensemble.length; ++i) {
            int imbalance;
            do {
                imbalance = 0;
                for (int j = 0; j < numClasses; ++j) {
                    int bit = (j == 1 && numClasses == 2)
                            ? 1 - codes[i][0]
                            : (this.classifierRandom.nextBoolean() ? 1 : 0);
                    codes[i][j] = bit;
                    imbalance += bit == 1 ? 1 : -1;
                }
                // The tolerance must depend on the number of classes: with an odd class count
                // a perfectly balanced row is impossible, so a difference of one is accepted.
            } while (imbalance * imbalance > numClasses % 2);
        }
        return codes;
    }

    /** Resets the member with the highest positive ADWIN error estimate, if there is one. */
    private void resetWorstMember() {
        double max = 0.0;
        int imax = -1;
        for (int i = 0; i < this.ensemble.length; ++i) {
            double estimate = this.ADError[i].getEstimation();
            if (max < estimate) {
                max = estimate;
                imax = i;
            }
        }
        if (imax != -1) {
            this.ensemble[imax].resetLearning();
            this.ADError[imax] = new ADWIN(this.deltaAdwinOption.getValue());
            this.scms[imax] = 0.0;
            this.swms[imax] = 0.0;
        }
    }

    /**
     * Returns the voting weight of member {@code i}: log((1-e)/e) + log(K-1), where e is its
     * boosting-weighted training error. Members with e == 0 or e > 0.5 get only the log(K-1)
     * term, and members with no statistics yet (just reset) get 0.
     */
    protected double getEnsembleMemberWeight(int i) {
        double seen = this.scms[i] + this.swms[i];
        if (seen == 0.0) {
            // Without this guard 0/0 yields NaN, which would poison summed output-code votes.
            return 0.0;
        }
        double em = this.swms[i] / seen;
        if (em == 0.0 || em > 0.5) {
            return this.logKm1;
        }
        return Math.log((1.0 - em) / em) + this.logKm1;
    }

    /**
     * Combines the normalised votes of the members, each scaled by its member weight.
     * Delegates to {@link #getVotesForInstanceBinary(Instance)} in output-code mode.
     */
    public double[] getVotesForInstance(Instance inst) {
        if (this.outputCodesOption.isSet()) {
            return this.getVotesForInstanceBinary(inst);
        }
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; ++i) {
            double memberWeight = this.getEnsembleMemberWeight(i);
            if (!(memberWeight > 0.0)) {
                // Stop at the first member without positive weight; every later member was
                // trained with a lambda_d derived from the members before it.
                break;
            }
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                vote.scaleValues(memberWeight);
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Output-code voting: each member predicts a binary label, and its weight is added to
     * every original class whose code in that member's row matches the prediction.
     * Returns all-zero votes when no code matrix has been drawn yet.
     */
    public double[] getVotesForInstanceBinary(Instance inst) {
        int numClasses = inst.numClasses();
        double[] combinedVote = new double[numClasses];
        if (this.initMatrixCodes || this.matrixCodes == null) {
            return combinedVote;
        }
        Instance weightedInst = (Instance) inst.copy();
        for (int i = 0; i < this.ensemble.length; ++i) {
            weightedInst.setClassValue(this.matrixCodes[i][(int) inst.classValue()]);
            double[] vote = this.ensemble[i].getVotesForInstance(weightedInst);
            int voteClass = (vote.length == 2 && vote[1] > vote[0]) ? 1 : 0;
            double memberWeight = this.getEnsembleMemberWeight(i);
            for (int j = 0; j < numClasses; ++j) {
                if (this.matrixCodes[i][j] == voteClass) {
                    combinedVote[j] += memberWeight;
                }
            }
        }
        return combinedVote;
    }

    public boolean isRandomizable() {
        return true;
    }

    public void getModelDescription(StringBuilder out, int indent) {
    }

    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[]{
            new Measurement("ensemble size", this.ensemble != null ? (double) this.ensemble.length : 0.0),
            new Measurement("change detections", this.numberOfChangesDetected)
        };
    }

    public Classifier[] getSubClassifiers() {
        return (Classifier[]) this.ensemble.clone();
    }
}

